2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
19 #define __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__
21 #include "cker/eigen/eigen_convolution_helpers.h"
23 // Note this header is used in both TF and TFLite.
30 // WARNING: Most of the code here implicitly assumes that the matrix is in
31 // ColMajor layout. This is guaranteed by the tensor contraction (see
32 // TensorContraction.h).
34 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
35 // We don't want to actually extract image patches and reshape the result into
36 // a matrix (this involves allocating huge extra memory), so the patch
37 // extraction and reshape operations are implicit.
39 // TensorContractionInputMapper takes a matrix index and returns the coefficient
40 // (or the packet) of the "virtual tensor", that would be at that index if we
41 // were to actually reshape the result of patch extraction.
43 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
44 // at the given vertical and horizontal offsets.
46 // "Virtual matrix" dimensions:
47 // *0: kernelChannels * kernelRows * kernelCols;
48 // 1: out_height * out_width; * OTHERS (e.g. batches, etc...)
50 // *) extracted patches are contiguous in memory (innermost dimension, assuming col-major layout)
53 // With these dimensions:
54 // row - offset within a single patch (in code: patchId)
55 // col - index of the extracted patch (in code: patchIndex)
56 // patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
58 // TODO(ezhulenev): Consolidate this part of the code with the image patch
59 // extraction code since they are both very similar.
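//
// A small worked example (illustrative numbers only, not taken from the code): for a
// 3x3 kernel over an 8-channel input, one extracted patch holds
// kernelChannels * kernelRows * kernelCols = 8 * 3 * 3 = 72 values, so the
// "virtual matrix" has 72 rows. With a 4x5 convolution output and a batch of 2 it
// has out_height * out_width * batches = 4 * 5 * 2 = 40 columns. The entry at
// (row = 10, col = 7) is the value at offset 10 inside patch number 7:
// depth 10 % 8 = 2, patch row (10 / 8) % 3 = 1, patch col (10 / 8) / 3 = 0.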
61 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
62 typename Scalar_, typename Index, typename nocontract_t, typename contract_t, int Side,
63 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
64 class TensorContractionInputMapper<
67 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
69 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
72 typedef Scalar_ Scalar;
74 typedef TensorContractionInputMapper<
77 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
79 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
82 typedef TensorContractionSubMapper<
85 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
87 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
90 typedef SubMapper VectorMapper;
91 typedef SubMapper LinearMapper;
92 typedef typename packet_traits<Scalar>::type Packet;
94 typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
97 TensorContractionInputMapper(
98 const TensorEvaluator<
99 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
101 const nocontract_t &, const nocontract_t &, const contract_t &, const contract_t &)
102 : m_impl(tensor.impl().impl())
106 if (internal::traits<ArgType>::Layout == ColMajor)
108 patch_depth = tensor.impl().dimensions()[0];
109 patch_rows = tensor.impl().dimensions()[1];
110 m_patch_cols = tensor.impl().dimensions()[2];
111 m_num_patches = tensor.impl().dimensions()[3];
115 const size_t NumDims = tensor.impl().dimensions().size();
116 patch_depth = tensor.impl().dimensions()[NumDims - 1];
117 patch_rows = tensor.impl().dimensions()[NumDims - 2];
118 m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
119 m_num_patches = tensor.impl().dimensions()[NumDims - 4];
122 // Strides for navigating through the single patch.
123 m_patch_row_stride = patch_depth;
124 m_patch_col_stride = patch_rows * m_patch_row_stride;
126 m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
127 m_patch_col_inflate_strides = tensor.impl().colInflateStride();
129 m_colStride = patch_rows;
131 m_outputRows = tensor.impl().outputRows();
132 m_outputCols = tensor.impl().outputCols();
133 m_row_strides = tensor.impl().userRowStride();
134 m_col_strides = tensor.impl().userColStride();
136 m_in_row_strides = tensor.impl().userInRowStride();
137 m_in_col_strides = tensor.impl().userInColStride();
139 if (internal::traits<ArgType>::Layout == ColMajor)
141 m_inputRows = tensor.impl().impl().dimensions()[1];
142 m_inputCols = tensor.impl().impl().dimensions()[2];
146 const int NumDims = tensor.impl().impl().dimensions().size();
147 m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
148 m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
151 m_rowInputStride = patch_depth;
152 m_colInputStride = patch_depth * m_inputRows;
153 m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
155 m_rowPaddingTop = tensor.impl().rowPaddingTop();
156 m_colPaddingLeft = tensor.impl().colPaddingLeft();
158 m_fastPatchRowStride = internal::TensorIntDivisor<Index>(m_patch_row_stride);
159 m_fastPatchColStride = internal::TensorIntDivisor<Index>(m_patch_col_stride);
160 m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
161 m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
162 m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
163 m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
164 m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
165 m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
169 TensorContractionInputMapper(const TensorContractionInputMapper &base_mapper)
170 : m_impl(base_mapper.m_impl)
172 m_patch_cols = base_mapper.m_patch_cols;
173 m_num_patches = base_mapper.m_num_patches;
175 m_patch_row_stride = base_mapper.m_patch_row_stride;
176 m_patch_col_stride = base_mapper.m_patch_col_stride;
178 m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
179 m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
181 m_colStride = base_mapper.m_colStride;
183 m_rowInputStride = base_mapper.m_rowInputStride;
184 m_colInputStride = base_mapper.m_colInputStride;
185 m_patchInputStride = base_mapper.m_patchInputStride;
187 m_inputRows = base_mapper.m_inputRows;
188 m_inputCols = base_mapper.m_inputCols;
190 m_outputRows = base_mapper.m_outputRows;
191 m_outputCols = base_mapper.m_outputCols;
192 m_row_strides = base_mapper.m_row_strides;
193 m_col_strides = base_mapper.m_col_strides;
195 m_in_row_strides = base_mapper.m_in_row_strides;
196 m_in_col_strides = base_mapper.m_in_col_strides;
198 m_rowPaddingTop = base_mapper.m_rowPaddingTop;
199 m_colPaddingLeft = base_mapper.m_colPaddingLeft;
201 m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
202 m_fastPatchColStride = base_mapper.m_fastPatchColStride;
203 m_fastInputRowStride = base_mapper.m_fastInputRowStride;
204 m_fastInputColStride = base_mapper.m_fastInputColStride;
205 m_fastNumPatches = base_mapper.m_fastNumPatches;
206 m_fastColStride = base_mapper.m_fastColStride;
207 m_fastOutputRows = base_mapper.m_fastOutputRows;
208 m_fastDimZero = base_mapper.m_fastDimZero;
211 // If true, turns off some optimizations for loading packets since the image
212 // patches are "non-standard": they have non-trivial strides or inflations
213 // in the input.
215 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const
217 return m_in_row_strides != 1 || m_in_col_strides != 1 || m_patch_row_inflate_strides != 1 ||
218 m_patch_col_inflate_strides != 1;
222 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const
224 return SubMapper(*this, i, j);
228 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const
230 return LinearMapper(*this, i, j);
234 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const
236 Index rowIndex, colIndex, otherIndex;
237 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
238 return loadCoeff(row, rowIndex, colIndex, otherIndex);
241 // Load the coefficient at the patchIndex location instead of the usual
243 // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the gpu code.
246 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const
248 Index rowIndex, colIndex, otherIndex;
249 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
250 return loadCoeff(row, rowIndex, colIndex, otherIndex);
254 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const
256 Index rowIndex, colIndex, otherIndex;
257 computeBaseIndices(0, rowIndex, colIndex, otherIndex);
258 return loadPacket(row, rowIndex, colIndex, otherIndex);
261 // Load the packet at the patchIndex location instead of the usual m_rowIndex,
262 // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
264 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const
266 Index rowIndex, colIndex, otherIndex;
267 computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
268 return loadPacket(row, rowIndex, colIndex, otherIndex);
272 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device> &impl() const { return m_impl; }
275 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
277 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
279 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
282 friend class TensorContractionSubMapper<
285 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
287 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>;
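// The loadCoeff()/loadPacket() helpers below essentially evaluate the following
// sketch (informal; the real code also handles in/inflate strides and padding):
//
//   patchOffset = patchId / patch_depth;
//   colOffset   = patchOffset / patch_rows;
//   rowOffset   = patchOffset - colOffset * patch_rows;
//   depth       = patchId - patchOffset * patch_depth;
//   inputIndex  = depth + (rowIndex + rowOffset) * m_rowInputStride +
//                 (colIndex + colOffset) * m_colInputStride + otherIndex;
//
// Coefficients whose input coordinates fall outside the input (i.e. into the
// padded region) are returned as zero.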
289 // Load coefficient from a patch specified by the "within patch offset"
290 // (patchId) and the precomputed indices of the first element of the patch.
292 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex,
293 Index otherIndex) const
295 // Find the offset of the element wrt the location of the first element.
296 const Index patchOffset = patchId / m_fastDimZero;
298 const Index colOffset = patchOffset / m_fastColStride;
299 const Index inputCol = colIndex + colOffset * m_in_col_strides;
300 const Index origInputCol = (m_patch_col_inflate_strides == 1)
302 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
304 const Index rowOffset = patchOffset - colOffset * m_colStride;
305 const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
306 const Index origInputRow = (m_patch_row_inflate_strides == 1)
308 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
309 if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
310 origInputRow >= m_inputRows || (inputCol != origInputCol * m_patch_col_inflate_strides) ||
311 (inputRow != origInputRow * m_patch_row_inflate_strides))
315 const Index depth = patchId - patchOffset * patchDepth();
316 const Index inputIndex =
317 depth + origInputRow * m_rowInputStride + origInputCol * m_colInputStride + otherIndex;
318 return m_impl.coeff(inputIndex);
321 // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
322 // and `in_strides` equal to 1 (template specialization without templates).
324 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex,
325 Index otherIndex) const
327 eigen_assert(!nonStandardPatches());
329 // Find the offset of the element wrt the location of the first element.
330 const Index patchOffset = patchId / m_fastDimZero;
331 const Index colOffset = patchOffset / m_fastColStride;
332 const Index rowOffset = patchOffset - colOffset * m_colStride;
333 const Index inputCol = colIndex + colOffset;
334 const Index inputRow = rowIndex + rowOffset;
335 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows)
339 const Index depth = patchId - patchOffset * patchDepth();
340 const Index inputIndex =
341 depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
342 return m_impl.coeff(inputIndex);
345 // Load packet from a patch specified by the "within patch offset"
346 // (patchId) and the precomputed indices of the first element of the patch.
348 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex,
349 Index otherIndex) const
351 const Index packetSize = internal::unpacket_traits<Packet>::size;
352 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
353 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
355 if (nonStandardPatches())
357 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
359 typedef decltype(m_impl) TensorEvaluatorT;
360 return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex, colIndex, otherIndex);
363 // Helper function to load a 'partial' packet - this is the single column
364 // part of a packet that is split across two columns. In the 'partial' packet,
365 // the elements corresponding to the column (specified through colOffset) are
366 // loaded and the rest of the elements are zero-filled into the 'partial'
367 // packet. This function is called from loadPacketStandardFromTwoColumns().
368 // This code path is exercised only when the packet type supports masked load
369 // and when the partial packet load is available in the TensorEvaluator.
371 EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(Index rowIndex, Index colIndex,
372 Index otherIndex, Index patchId,
374 const Index patchOffsets[],
375 Index colOffset) const
377 const Index inputCol = colIndex + colOffset;
378 const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
379 patchOffsets[1] - colOffset * m_colStride};
380 const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
382 if (inputRows[0] >= m_inputRows || inputRows[1] < 0 || inputCol >= m_inputCols || inputCol < 0)
384 // Partial packet is all zeros
385 return internal::pset1<Packet>(Scalar(0));
387 else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
389 // From inputIndex-span[0], we need to load elements starting from index
390 // span[0] all the way up to (and including) span[1].
391 const Index depth = patchId - patchOffsets[0] * patchDepth();
392 const Index inputIndex =
393 depth + inputRows[0] * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
394 return m_impl.template partialPacket<Packet>(inputIndex - span[0],
395 mask<Packet>(span[0], span[1] + 1));
399 // Using slow path for this partial packet.
400 // We need to load elements starting from index span[0] all the way up to
401 // (and including) span[1]. We split this load into 3 parts:
402 // 0 : span[0]-1 - Zeros will be loaded for these indices
403 // span[0] : span[1] - Elements will be loaded here for these indices
404 // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
405 const Index packetSize = internal::unpacket_traits<Packet>::size;
407 typename internal::remove_const<Scalar>::type values[packetSize];
408 for (int i = 0; i < span[0]; ++i)
409 values[i] = Scalar(0);
410 for (int i = span[0]; i < span[1] + 1; ++i)
411 values[i] = loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
412 for (int i = span[1] + 1; i < packetSize; ++i)
413 values[i] = Scalar(0);
414 return internal::pload<Packet>(values);
418 // Helper function to load a packet that is split across two columns.
419 // If required, this function is called from loadPacketStandard() when the
420 // packet type supports masked load and when the partial packet load is
421 // available in the TensorEvaluator.
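//
// Illustrative example (hypothetical sizes, not taken from the code): with
// patch_depth = 3 and patch_rows = 5, one patch column spans 3 * 5 = 15
// elements. A packet of size 8 starting at patchId = 12 covers offsets 12..19,
// i.e. 12..14 from column c and 15..19 from column c + 1. Here
// patchIdSplit = 1 * 5 * 3 - 1 = 14, so the two partial packets load spans
// {0, 2} and {3, 7} of the packet and are then OR-ed together.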
423 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(Index patchId, Index rowIndex,
424 Index colIndex, Index otherIndex,
425 const Index patchOffsets[],
426 const Index colOffsets[]) const
428 eigen_assert(colOffsets[1] == colOffsets[0] + 1);
429 const Index packetSize = internal::unpacket_traits<Packet>::size;
431 // Packet to load will be split into 2 parts where each part spans a single
432 // column. First determine where to split.
433 const Index patchIdSplit = ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
434 const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
436 // patchIds[i]: patchId corresponding to partial packet i
437 // spans[i]: Start and end indices corresponding to the elements
438 // to be loaded for partial packet i
439 // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
440 const Index patchIds[2] = {patchId, patchIdSplit + 1};
441 const Index spans[2][2] = {{0, patchIdSplit - patchId},
442 {patchIdSplit - patchId + 1, packetSize - 1}};
443 const Index patchOffsets2Cols[2][2] = {{patchOffsets[0], patchOffsetSplit},
444 {patchOffsetSplit + 1, patchOffsets[1]}};
446 // Load partial packets and do bit-wise OR to generate required packet
447 return internal::por<Packet>(
448 loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0], spans[0],
449 patchOffsets2Cols[0], colOffsets[0]),
450 loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1], spans[1],
451 patchOffsets2Cols[1], colOffsets[1]));
454 // Helper function to load a packet that is present in a single column.
455 // If required, this function is called from loadPacketStandard().
457 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(Index patchId, Index rowIndex,
458 Index colIndex, Index otherIndex,
459 const Index patchOffsets[],
460 const Index colOffsets[],
461 const Index inputCols[]) const
463 eigen_assert(colOffsets[0] == colOffsets[1]);
464 const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
465 patchOffsets[1] - colOffsets[1] * m_colStride};
466 eigen_assert(rowOffsets[0] <= rowOffsets[1]);
467 const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};
469 if (inputRows[0] >= m_inputRows || inputRows[1] < 0)
472 return internal::pset1<Packet>(Scalar(0)); // all zeros
475 if (inputRows[0] >= 0 && inputRows[1] < m_inputRows)
478 const Index depth = patchId - patchOffsets[0] * patchDepth();
479 const Index inputIndex =
480 depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
481 return m_impl.template packet<Unaligned>(inputIndex);
483 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
486 // Load standard packet from a patch specified by the "within patch offset"
487 // (patchId) and the precomputed indices of the first element of the patch.
488 // This function will be called if partial packet loading is not available
489 // for the TensorEvaluator or if the packet type does not support masked load.
491 template <typename PacketT, typename TensorEvaluatorT>
492 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
493 !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
494 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
496 const Index packetSize = internal::unpacket_traits<Packet>::size;
497 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
498 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
500 eigen_assert(!nonStandardPatches());
502 if ((patchDepth() % packetSize) == 0)
504 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
507 // Offsets and input calculation here are identical to
508 // loadCoeffStandard(...), but repeated twice.
509 const Index patchOffsets[2] = {patchId / m_fastDimZero,
510 (patchId + packetSize - 1) / m_fastDimZero};
511 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
512 patchOffsets[1] / m_fastColStride};
513 const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
515 if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
518 return internal::pset1<Packet>(Scalar(0));
520 if (inputCols[0] == inputCols[1])
522 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
523 patchOffsets, colOffsets, inputCols);
525 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
528 // Load standard packet from a patch specified by the "within patch offset"
529 // (patchId) and the precomputed indices of the first element of the patch.
530 // This function will be called if partial packet loading is available for
531 // the TensorEvaluator and if the packet type supports masked load.
532 // The only difference between this and the other case is that if the packet
533 // to load is split across two columns, then in this case instead of going to
534 // the slow (element-by-element) load, we load two packets - each containing
535 // elements from one of the columns (rest of the elements of the packets are
536 // zeroes), and then combine these two packets to generate the required
537 // packet. The idea is to enable fast load (if possible) of these 'partial' packets.
539 template <typename PacketT, typename TensorEvaluatorT>
540 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
541 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
542 loadPacketStandard(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const
544 const Index packetSize = internal::unpacket_traits<PacketT>::size;
545 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
546 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
548 eigen_assert(!nonStandardPatches());
550 if ((patchDepth() % packetSize) == 0)
552 return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
555 // Offsets and input calculation here are identical to
556 // loadCoeffStandard(...), but repeated twice.
557 const Index patchOffsets[2] = {patchId / m_fastDimZero,
558 (patchId + packetSize - 1) / m_fastDimZero};
559 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
560 patchOffsets[1] / m_fastColStride};
561 const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
563 if (inputCols[0] >= m_inputCols || inputCols[1] < 0)
566 return internal::pset1<PacketT>(Scalar(0));
568 if (inputCols[0] == inputCols[1])
570 return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex, otherIndex,
571 patchOffsets, colOffsets, inputCols);
573 if (inputCols[1] == inputCols[0] + 1)
575 return loadPacketStandardFromTwoColumns(patchId, rowIndex, colIndex, otherIndex, patchOffsets,
578 return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
582 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex, Index colIndex,
583 Index otherIndex) const
585 const Index packetSize = internal::unpacket_traits<Packet>::size;
586 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
587 eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
589 eigen_assert(!nonStandardPatches());
590 eigen_assert((patchDepth() % packetSize) == 0);
591 // Find the offset of the element wrt the location of the first element.
592 const Index patchOffset = patchId / m_fastDimZero;
593 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
595 const Index colOffset = patchOffset / m_fastColStride;
596 const Index rowOffset = patchOffset - colOffset * m_colStride;
597 const Index inputCol = colIndex + colOffset;
598 const Index inputRow = rowIndex + rowOffset;
599 if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows)
602 return internal::pset1<Packet>(Scalar(0));
605 const Index depth = patchId - patchOffset * patchDepth();
606 const Index inputIndex =
607 depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
608 return m_impl.template packet<Unaligned>(inputIndex);
611 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex,
613 Index otherIndex) const
615 const int packetSize = internal::unpacket_traits<Packet>::size;
617 typename internal::remove_const<Scalar>::type values[packetSize];
618 for (int i = 0; i < packetSize; ++i)
620 values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
622 Packet rslt = internal::pload<Packet>(values);
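// Illustrative note (hypothetical numbers): computeBaseIndices() below turns a
// column of the "virtual matrix" into the input-space coordinates of the first
// element of that patch. For example, with m_outputRows = 4, m_num_patches = 12,
// unit user strides and a padding of 1, patchIndex = 7 gives otherIndex = 0,
// colIndex = 7 / 4 = 1 and rowIndex = 7 - 1 * 4 = 3; after applying strides and
// padding, the patch starts at input row 3 - 1 = 2 and input column 1 - 1 = 0.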
626 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void
627 computeBaseIndices(Index patchIndex, Index &rowIndex, Index &colIndex, Index &otherIndex) const
629 const size_t NumInputDims =
630 array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
631 otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
632 const Index patch2DIndex =
633 (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
634 otherIndex *= m_patchInputStride;
635 colIndex = patch2DIndex / m_fastOutputRows;
636 rowIndex = patch2DIndex - colIndex * m_outputRows;
637 colIndex = colIndex * m_col_strides - m_colPaddingLeft;
638 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
641 Index m_patch_cols; // number of columns in the patch
642 Index m_num_patches; // number of patches to extract.
644 // Strides for navigating through the single patch.
645 Index m_patch_row_stride;
646 Index m_patch_col_stride;
647 internal::TensorIntDivisor<Index> m_fastPatchRowStride;
648 internal::TensorIntDivisor<Index> m_fastPatchColStride;
650 Index m_patch_row_inflate_strides; // the strides for row inflation in the image patch
652 Index m_patch_col_inflate_strides; // the strides for col inflation in the image patch
654 // Fast representation of inflation strides.
655 internal::TensorIntDivisor<Index> m_fastInputRowStride;
656 internal::TensorIntDivisor<Index> m_fastInputColStride;
660 internal::TensorIntDivisor<Index> m_fastNumPatches;
661 internal::TensorIntDivisor<Index> m_fastColStride;
663 Index m_rowInputStride; // row stride in the input tensor
664 Index m_colInputStride; // col stride in the input tensor
665 Index m_patchInputStride; // patch stride in the input tensor
667 Index m_inputRows; // Number of rows in the input tensor
668 Index m_inputCols; // Number of cols in the input tensor
670 Index m_outputRows; // Number of convolution output rows
671 Index m_outputCols; // Number of convolution output columns
673 Index m_row_strides; // User specified row stride
674 Index m_col_strides; // User specified col stride
676 Index m_in_row_strides; // User specified input row stride
677 Index m_in_col_strides; // User specified input col stride
679 Index m_rowPaddingTop; // Row padding
680 Index m_colPaddingLeft; // Column padding
682 internal::TensorIntDivisor<Index> m_fastOutputRows;
683 internal::TensorIntDivisor<Index> m_fastDimZero;
685 const TensorEvaluator<ArgType, Device> m_impl;
688 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
689 typename Scalar, typename Index, typename nocontract_t, typename contract_t, int Side,
690 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
691 class TensorContractionSubMapper<
694 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
696 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
699 typedef typename packet_traits<Scalar>::type Packet;
700 typedef typename packet_traits<Scalar>::half HalfPacket;
702 typedef TensorContractionInputMapper<
705 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
707 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
710 typedef TensorContractionSubMapper<
713 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
715 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
718 typedef Self LinearMapper;
720 typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
722 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const ParentMapper &base_mapper,
725 : m_depth_offset(vert_offset), m_col_offset(horiz_offset), m_base_mapper(base_mapper)
727 m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
729 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(const Self &base_mapper,
732 : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
733 m_col_offset(horiz_offset + base_mapper.m_col_offset),
734 m_base_mapper(base_mapper.m_base_mapper)
736 m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
738 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const
740 return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
742 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const
744 return m_base_mapper(i + m_depth_offset, j + m_col_offset);
747 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const
749 return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
751 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const
753 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, j + m_col_offset);
755 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i) const
757 return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex, m_colIndex,
761 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const
763 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
765 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i) const
767 typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
768 return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
769 i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
771 template <typename Packet> EIGEN_DEVICE_FUNC bool aligned(Index) const { return false; }
774 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { return m_base_mapper.nonStandardPatches(); }
776 // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
777 // index respectively that fits into the peeled_k elements starting at m_depth_offset.
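//
// Illustrative example (hypothetical sizes): with an 8 x 3 x 3 patch
// (depth x rows x cols), patchColStride() = 24 and patchRowStride() = 8. For
// m_depth_offset = 10 and peeled_k = 32, maxCol(32) = min(1 + 41 / 24, 3) = 2
// and maxRow(32, /*col=*/1) = min(1 + (41 - 24) / 8, 3) = 3: the peeled range
// touches patch columns 0..1 and, within column 1, all three patch rows.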
781 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const
783 const Index max_col =
784 (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) / fastPatchColStride();
785 return std::min<Index>(1 + max_col, patchCols());
789 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, const Index col) const
791 const Index max_row =
792 (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) - col * patchColStride()) /
793 fastPatchRowStride();
794 return std::min<Index>(1 + max_row, patchRows());
798 EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col, Index row) const
800 const Index max_depth = m_depth_offset + peeled_k - //
801 col * patchColStride() - //
802 row * patchRowStride();
803 return std::min<Index>(max_depth, patchDepth());
806 // MaxDepth uses only the remaining number of elements in the peeled_k.
808 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, const Index start_depth) const
810 return std::min<Index>(start_depth + num_elements, patchDepth());
813 // Every register matters in this code, so sometimes to prevent register
814 // spilling, instead of the variable that you would expect to see, we use
815 // another one that is guaranteed to have the same value. E.g. patch depth is
816 // always the same as input depth, and it's also the same as input row stride.
817 // A bunch of other parameters have similar relations.
819 typedef internal::TensorIntDivisor<Index> IndexDivisor;
822 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; }
824 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_base_mapper.m_colStride; }
826 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_base_mapper.m_patch_cols; }
829 EIGEN_ALWAYS_INLINE Index patchRowStride() const
831 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
832 "Patch depth must be equal to patch row stride.");
836 EIGEN_ALWAYS_INLINE Index patchColStride() const { return m_base_mapper.m_patch_col_stride; }
839 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const
841 eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
842 "Patch depth must be equal to patch row stride.");
843 return m_base_mapper.m_fastDimZero; // patch_depth
846 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const
848 return m_base_mapper.m_fastPatchColStride;
852 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const
854 const Index inputIndex = depth + baseIndex;
855 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
858 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, const Index baseIndex) const
860 const Index inputIndex = depth + baseIndex;
861 return m_base_mapper.m_impl.coeff(inputIndex);
863 template <typename PacketT = Packet>
864 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
865 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, PacketT>::type
866 partialPacketNoPadding(const Index depth, const Index baseIndex, Index num_coeffs) const
868 const Index inputIndex = depth + baseIndex;
869 return m_base_mapper.m_impl.template partialPacket<PacketT>(inputIndex,
870 mask<PacketT>(0, num_coeffs));
873 EIGEN_ALWAYS_INLINE bool hasPadding() const
875 // TODO(ezhulenev): It does seem that for an inflated filter it's still
876 // possible to guarantee "no padding or skipping" for non-standard packing.
877 if (nonStandardPatches())
880 // Non-zero padding before.
881 if (m_base_mapper.m_rowPaddingTop > 0)
883 if (m_base_mapper.m_colPaddingLeft > 0)
886 // Non-zero padding after in rows.
887 const Index last_row = (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
888 if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows)
891 // Non-zero padding after in cols.
892 const Index last_col = (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
893 if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols)
899 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const
901 const Index r = m_rowIndex + row;
902 return r < 0 || r >= m_base_mapper.m_inputRows;
905 EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row, const Index last_row) const
907 return m_rowIndex + first_row < 0 || m_rowIndex + last_row >= m_base_mapper.m_inputRows;
910 EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row, Index *orig_row) const
912 eigen_assert(nonStandardPatches());
914 const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
915 *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
917 : ((input_row >= 0) ? (input_row / m_base_mapper.m_fastInputRowStride) : 0);
919 return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
920 (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
923 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const
925 const Index c = m_colIndex + col;
926 return c < 0 || c >= m_base_mapper.m_inputCols;
929 EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col, Index *orig_col) const
931 eigen_assert(nonStandardPatches());
933 const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
934 *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
936 : ((input_col >= 0) ? (input_col / m_base_mapper.m_fastInputColStride) : 0);
938 return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
939 (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
942 EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const
944 const Index r = m_rowIndex + row;
945 const Index c = m_colIndex + col;
946 return r * m_base_mapper.m_rowInputStride + c * m_base_mapper.m_colInputStride + m_otherIndex;
948 // Compute a base index when original input row and column were precomputed
949 // using padOrSkipRow and padOrSkipCol. Used only for non-standard patches.
951 EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row, const Index orig_col) const
953 return orig_row * m_base_mapper.m_rowInputStride + orig_col * m_base_mapper.m_colInputStride +
958 EIGEN_ALWAYS_INLINE Index rowStride() const { return m_base_mapper.m_row_strides; }
960 EIGEN_ALWAYS_INLINE Index colStride() const { return m_base_mapper.m_col_strides; }
963 EIGEN_ALWAYS_INLINE Index rowOffset() const
965 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
966 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
967 return patchOffset - colOffset * m_base_mapper.m_colStride;
971 EIGEN_ALWAYS_INLINE Index colOffset() const
973 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
974 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
979 EIGEN_ALWAYS_INLINE Index depthOffset() const { return m_depth_offset % patchDepth(); }
981 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const
983 return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
987 Index m_depth_offset; // First row in the input matrix
988 Index m_col_offset; // First col in the input matrix
990 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
991 // indices for the first element in a patch specified by col_offset
992 // (see computeBaseIndices(...) for details).
997 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
998 // performs better in benchmarks.
1001 // Arrange a block of the right input matrix (in our case it's always a "virtual
1002 // matrix" constructed from extracted image patches) in contiguous memory.
1004 // Given column major input (A0 beside A1 in memory):
1005 // A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
1006 // A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
1007 // A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
1008 // A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
1009 // A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
1010 // A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
1011 // A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
1012 // A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
1016 // *) A, B, C, ... - patches extracted from the original input.
1017 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1019 // The traversal (packed rhs memory) order (B0 beside A0 in memory):
1020 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1021 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1023 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1025 // This traversal order must be the same as in default gemm_pack_rhs defined in
1026 // GeneralBlockPanelKernel.h.
1028 // *) nr - number of registers along the 'n' dimension.
1029 // See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1030 // Multiplication" paper.
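//
// Roughly speaking (an informal restatement, not upstream text): the packing
// walks a block of nr = 4 virtual-matrix columns together and copies their
// coefficients depth-first into `block` in the interleaved order shown above,
// so the gebp kernel can consume four rhs columns in one sequential sweep.
// Columns that do not fill a complete block of four are packed one at a time
// in the nr == 1 tail loops of the operator() implementations below.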
1031 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1032 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1033 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1035 struct gemm_pack_rhs<
1037 TensorContractionSubMapper<
1040 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1042 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered,
1044 nr, ColMajor, false, false>
1046 typedef TensorContractionSubMapper<
1049 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1051 nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>
1053 typedef SubMapper DataMapper;
1054 typedef typename packet_traits<Scalar>::type Packet;
1056 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1059 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1060 Index stride = 0, Index offset = 0) const
1062 eigen_assert(stride == 0);
1063 eigen_assert(offset == 0);
1067 const Index packet_cols4 = (cols / 4) * 4;
1068 const Index peeled_k = (depth / packet_size) * packet_size;
1069 const bool non_standard_patches = rhs.nonStandardPatches();
1071 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1073 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1074 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1075 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1076 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1079 if ((packet_size % 4) == 0 && !non_standard_patches)
1082 // Iterate over patch columns and rows if we know that a single
1083 // packet does not span across multiple rows or columns.
1084 if ((rhs.patchDepth() % packet_size) == 0)
1086 const Index start_col = rhs.colOffset();
1087 const Index max_col = rhs.maxCol(peeled_k);
1089 for (Index c = start_col; c < max_col; ++c)
1091 eigen_assert(k <= peeled_k);
1093 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1094 const Index max_row = rhs.maxRow(peeled_k, c);
1096 const bool pad_col0 = dm0.padCol(c);
1097 const bool pad_col1 = dm1.padCol(c);
1098 const bool pad_col2 = dm2.padCol(c);
1099 const bool pad_col3 = dm3.padCol(c);
1101 // Check if we can squeeze reads along the `row` and `depth`
1102 // dimensions (two innermost dimensions).
1103 if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1104 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1105 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1106 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1107 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1109 // Compute how many elements we can squeeze read.
1110 const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1112 // Upper bound for the number of elements in the depth dimension
1113 // that we can squeeze read.
1114 const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1116 // Do not overshoot beyond the block size.
1117 const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1118 eigen_assert((max_depth - start_depth) % packet_size == 0);
1120 const Index idx0 = dm0.baseIndex(start_row, c);
1121 const Index idx1 = dm1.baseIndex(start_row, c);
1122 const Index idx2 = dm2.baseIndex(start_row, c);
1123 const Index idx3 = dm3.baseIndex(start_row, c);
1125 for (Index d = start_depth; d < max_depth; d += packet_size)
1127 eigen_assert(k < peeled_k);
1128 PacketBlock<Packet, 4> kernel;
1129 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1130 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1131 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1132 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1134 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1135 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1136 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1137 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1138 block += 4 * packet_size;
1142 // Go to the next column.
1146 // If we can't squeeze reads, process rows one by one.
1147 for (Index r = start_row; r < max_row; ++r)
1149 eigen_assert(k <= peeled_k);
1151 const bool pad0 = pad_col0 || dm0.padRow(r);
1152 const bool pad1 = pad_col1 || dm1.padRow(r);
1153 const bool pad2 = pad_col2 || dm2.padRow(r);
1154 const bool pad3 = pad_col3 || dm3.padRow(r);
1156 const Index idx0 = dm0.baseIndex(r, c);
1157 const Index idx1 = dm1.baseIndex(r, c);
1158 const Index idx2 = dm2.baseIndex(r, c);
1159 const Index idx3 = dm3.baseIndex(r, c);
1161 const Index start_depth =
1162 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1163 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1164 eigen_assert((max_depth - start_depth) % packet_size == 0);
1166 for (Index d = start_depth; d < max_depth; d += packet_size)
1168 eigen_assert(k < peeled_k);
1169 PacketBlock<Packet, 4> kernel;
1170 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1171 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1172 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1173 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1175 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1176 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1177 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1178 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1179 block += 4 * packet_size;
1185 // The loop above should fill peeled_k elements.
1186 eigen_assert(peeled_k == k);
1190 for (; k < peeled_k; k += packet_size)
1192 PacketBlock<Packet, 4> kernel;
1193 kernel.packet[0] = dm0.loadPacketStandard(k);
1194 kernel.packet[1] = dm1.loadPacketStandard(k);
1195 kernel.packet[2] = dm2.loadPacketStandard(k);
1196 kernel.packet[3] = dm3.loadPacketStandard(k);
1198 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1199 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1200 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1201 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1202 block += 4 * packet_size;
1207 // Copy the remaining coefficients of the column block after the peeled_k.
1208 if (!rhs.nonStandardPatches())
1210 for (; k < depth; k++)
1212 block[0] = dm0.loadCoeffStandard(k);
1213 block[1] = dm1.loadCoeffStandard(k);
1214 block[2] = dm2.loadCoeffStandard(k);
1215 block[3] = dm3.loadCoeffStandard(k);
1221 for (; k < depth; k++)
1232 // Copy the remaining columns one at a time (nr==1).
1233 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1235 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1236 for (Index k = 0; k < depth; k++)
1245 // Template specialization for packet_size = 2. We must special-case packet
1246 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1247 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1248 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1249 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1250 struct gemm_pack_rhs<
1252 TensorContractionSubMapper<
1255 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1257 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1258 nr, ColMajor, false, false>
1260 typedef TensorContractionSubMapper<
1263 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1265 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, Alignment>
1267 typedef SubMapper DataMapper;
1268 typedef typename packet_traits<Scalar>::type Packet;
1270 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1273 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1274 Index stride = 0, Index offset = 0) const
1276 eigen_assert(stride == 0);
1277 eigen_assert(offset == 0);
1282 const int packet_size = 2;
1283 const Index packet_cols4 = (cols / 4) * 4;
1284 const Index peeled_k = (depth / packet_size) * packet_size;
1285 const bool non_standard_patches = rhs.nonStandardPatches();
1287 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1289 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1290 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1291 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1292 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1295 if (!non_standard_patches)
1298 // Iterate over patch columns and rows if we know that a single
1299 // packet does not span across multiple rows or columns.
1300 if ((rhs.patchDepth() % packet_size) == 0)
1302 const Index start_col = rhs.colOffset();
1303 const Index max_col = rhs.maxCol(peeled_k);
1305 for (Index c = start_col; c < max_col; ++c)
1307 eigen_assert(k <= peeled_k);
1309 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1310 const Index max_row = rhs.maxRow(peeled_k, c);
1312 const bool pad_col0 = dm0.padCol(c);
1313 const bool pad_col1 = dm1.padCol(c);
1314 const bool pad_col2 = dm2.padCol(c);
1315 const bool pad_col3 = dm3.padCol(c);
1317 // We can squeeze reads along the `row` and `depth` dimensions if
1318 // the row stride is `1`, which means that `row` and `depth`
1319 // dimensions are contiguous (two innermost dimensions).
1320 if (rhs.rowStride() == 1 && //
1321 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && //
1322 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && //
1323 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && //
1324 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && //
1325 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1))
1327 // Compute how many elements we can squeeze read.
1328 const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
1330 // Upper bound for the number of elements in the depth dimension
1331 // that we can squeeze read.
1332 const Index squeeze_length = (max_row - start_row) * rhs.patchDepth() - start_depth;
1334 // Do not overshoot beyond the block size.
1335 const Index max_depth = start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1336 eigen_assert((max_depth - start_depth) % packet_size == 0);
1338 const Index idx0 = dm0.baseIndex(start_row, c);
1339 const Index idx1 = dm1.baseIndex(start_row, c);
1340 const Index idx2 = dm2.baseIndex(start_row, c);
1341 const Index idx3 = dm3.baseIndex(start_row, c);
1343 for (Index d = start_depth; d < max_depth; d += packet_size)
1345 PacketBlock<Packet, 2> kernel0;
1346 PacketBlock<Packet, 2> kernel1;
1347 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1348 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1349 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1350 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1351 ptranspose(kernel0);
1352 ptranspose(kernel1);
1353 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1354 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1355 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1356 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1357 block += 4 * packet_size;
1361 // Go to the next column.
1365 // If we can't squeeze reads, process rows one by one.
1366 for (Index r = start_row; r < max_row; ++r)
1368 eigen_assert(k <= peeled_k);
1370 const bool pad0 = pad_col0 || dm0.padRow(r);
1371 const bool pad1 = pad_col1 || dm1.padRow(r);
1372 const bool pad2 = pad_col2 || dm2.padRow(r);
1373 const bool pad3 = pad_col3 || dm3.padRow(r);
1375 const Index idx0 = dm0.baseIndex(r, c);
1376 const Index idx1 = dm1.baseIndex(r, c);
1377 const Index idx2 = dm2.baseIndex(r, c);
1378 const Index idx3 = dm3.baseIndex(r, c);
1380 const Index start_depth =
1381 ((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
1382 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1383 eigen_assert((max_depth - start_depth) % packet_size == 0);
1385 for (Index d = start_depth; d < max_depth; d += packet_size)
1387 eigen_assert(k < peeled_k);
1388 PacketBlock<Packet, 2> kernel0;
1389 PacketBlock<Packet, 2> kernel1;
1390 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx0);
1391 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx1);
1392 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx2);
1393 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) : rhs.packetNoPadding(d, idx3);
1394 ptranspose(kernel0);
1395 ptranspose(kernel1);
1396 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1397 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1398 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1399 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1400 block += 4 * packet_size;
1406 // The loop above should fill peeled_k elements.
1407 eigen_assert(peeled_k == k);
1411 // Packet can span multiple rows or columns, so we have to go
1412 // through the slower "standard" path.
1413 for (; k < peeled_k; k += packet_size)
1415 PacketBlock<Packet, 2> kernel0;
1416 PacketBlock<Packet, 2> kernel1;
1417 kernel0.packet[0] = dm0.loadPacketStandard(k);
1418 kernel0.packet[1] = dm1.loadPacketStandard(k);
1419 kernel1.packet[0] = dm2.loadPacketStandard(k);
1420 kernel1.packet[1] = dm3.loadPacketStandard(k);
1421 ptranspose(kernel0);
1422 ptranspose(kernel1);
1423 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1424 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1425 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1426 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1427 block += 4 * packet_size;
1432 // Copy the remaining coefficients of the column block after the peeled_k.
1433 if (!non_standard_patches)
1435 for (; k < depth; k++)
1437 block[0] = dm0.loadCoeffStandard(k);
1438 block[1] = dm1.loadCoeffStandard(k);
1439 block[2] = dm2.loadCoeffStandard(k);
1440 block[3] = dm3.loadCoeffStandard(k);
1446 for (; k < depth; k++)
1457 // Copy the remaining columns one at a time (nr==1).
1458 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1460 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1461 for (Index k = 0; k < depth; k++)
1470 // Special case for non-vectorized types such as float16.
1471 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, typename Device,
1472 typename Scalar, typename Index, typename nocontract_t, typename contract_t,
1473 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
1474 struct gemm_pack_rhs<
1476 TensorContractionSubMapper<
1479 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1481 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>,
1482 nr, ColMajor, false, false>
1484 typedef TensorContractionSubMapper<
1487 const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType>>,
1489 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
1491 typedef SubMapper DataMapper;
1493 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1496 EIGEN_DONT_INLINE void operator()(Scalar *block, const DataMapper &rhs, Index depth, Index cols,
1497 Index stride = 0, Index offset = 0) const
1499 eigen_assert(stride == 0);
1500 eigen_assert(offset == 0);
1505 const Index packet_cols4 = (cols / 4) * 4;
1507 for (Index j2 = 0; j2 < packet_cols4; j2 += 4)
1509 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1510 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1511 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1512 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1514 if (!rhs.nonStandardPatches())
1516 for (Index k = 0; k < depth; k++)
1518 block[0] = dm0.loadCoeffStandard(k);
1519 block[1] = dm1.loadCoeffStandard(k);
1520 block[2] = dm2.loadCoeffStandard(k);
1521 block[3] = dm3.loadCoeffStandard(k);
1527 for (Index k = 0; k < depth; k++)
1538 // Copy the remaining columns one at a time (nr==1).
1539 for (Index j2 = packet_cols4; j2 < cols; ++j2)
1541 const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1542 for (Index k = 0; k < depth; k++)
1550 } // end namespace internal
1552 /** SpatialConvolution
1553 * \ingroup CXX11_NeuralNetworks_Module
1555 * \brief Applies a 2D convolution over a multichannel input image.
1557 * The input parameter is expected to be a tensor with a rank of 3 or more
1558 * (channels, height, width, and optionally others).
1559 * The kernel parameter is expected to be a 4D tensor (filters, channels,
1560 * kernel_height, kernel_width).
1561 * The input and the kernel must both be in col-major layout. The result will
1562 * also be in col-major layout.
1564 * If col_in_stride, row_in_stride > 1, then applies convolution with holes
1565 * (aka atrous convolution), sampling every col_in_stride, row_in_stride input pixels.
1568 * If padding_top, padding_bottom, padding_left, or padding_right is specified,
1569 * then those paddings will be used to pad the input, and padding_type must be PADDING_VALID.
1572 * The result can be assigned to a tensor of rank equal to the rank of the
1573 * input. The dimensions of the result will be filters, height, width (and
1574 * others if applicable).
1576 * It is possible to swap the order of the width and height dimensions provided
1577 * that the same order is used in the input, the kernel, and the output.
1579 * It is also possible to add an output kernel to the contraction; the output
1580 * kernel is called by Eigen when it "finalizes" a block of the output tensor.
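 *
 * A hypothetical col-major usage sketch (shapes invented for illustration):
 *
 *   Eigen::Tensor<float, 4> input(3, 32, 32, 10);   // channels, rows, cols, batch
 *   Eigen::Tensor<float, 4> kernel(16, 3, 3, 3);    // filters, channels, rows, cols
 *   Eigen::Tensor<float, 4> output(16, 32, 32, 10); // PADDING_SAME, stride 1
 *   output = Eigen::SpatialConvolution(input, kernel);
 *
 * With PADDING_SAME the output height/width are divup(input, stride); with
 * PADDING_VALID they are divup(input_eff - kernel_eff + 1, stride), where the
 * effective sizes account for explicit padding and the in_stride dilation.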
1583 template <typename Input, typename Kernel, typename OutputKernel = const NoOpOutputKernel>
1584 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static const typename internal::conditional<
1585 internal::traits<Input>::Layout == ColMajor,
1587 const DSizes<typename internal::traits<Input>::Index,
1588 internal::traits<Input>::NumDimensions>,
1589 const TensorContractionOp<
1590 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1591 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1593 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1594 const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1595 const OutputKernel>>,
1597 const DSizes<typename internal::traits<Input>::Index,
1598 internal::traits<Input>::NumDimensions>,
1599 const TensorContractionOp<
1600 const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1601 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1602 const TensorImagePatchOp<Dynamic, Dynamic, const Input>>,
1603 const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>,
1605 const OutputKernel>>>::type
1606 SpatialConvolution(const Input &input, const Kernel &kernel, const Index row_stride = 1,
1607 const Index col_stride = 1, const PaddingType padding_type = PADDING_SAME,
1608 const Index row_in_stride = 1, const Index col_in_stride = 1,
1609 const OutputKernel &output_kernel = OutputKernel(), Index padding_top = 0,
1610 Index padding_bottom = 0, Index padding_left = 0, Index padding_right = 0)
1612 typedef typename internal::traits<Input>::Index TensorIndex;
1613 TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions,
1614 internal::traits<Input>::Layout, TensorIndex>>
1617 Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions,
1618 internal::traits<Kernel>::Layout, TensorIndex>>
1621 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1622 YOU_MADE_A_PROGRAMMING_MISTAKE)
1623 const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1625 const int NumDims = internal::traits<Input>::NumDimensions;
1627 // Number of filters to apply. This is the same as the output depth of the result.
1629 const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1630 // Number of channels. This is the same as the input depth.
1631 const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1632 const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1633 const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
1635 const Index kernelRowsEff = kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1636 const Index kernelColsEff = kernelCols + (kernelCols - 1) * (col_in_stride - 1);
1638 array<IndexPair<TensorIndex>, 1> contract_dims;
1639 contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1641 const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1642 const TensorIndex InputCols = isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1643 const bool padding_explicit = (padding_top || padding_bottom || padding_left || padding_right);
1645 TensorIndex out_height;
1646 TensorIndex out_width;
1647 switch (padding_type)
1651 const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1652 const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1653 out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1654 out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1659 eigen_assert(!padding_explicit);
1660 out_height = divup(InputRows, row_stride);
1661 out_width = divup(InputCols, col_stride);
1666 // Initialize unused variables to avoid a compiler warning
1669 eigen_assert(false && "unexpected padding");
1673 // Molds the output of the patch extraction code into a 2d tensor:
1674 // - the first dimension (dims[0]): the patch values to be multiplied with the kernels
1676 // - the second dimension (dims[1]): everything else
1677 DSizes<TensorIndex, 2> pre_contract_dims;
1680 pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1681 pre_contract_dims[1] = out_height * out_width;
1682 for (int i = 3; i < NumDims; ++i)
1684 pre_contract_dims[1] *= in.dimension(i);
1689 pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1690 pre_contract_dims[0] = out_height * out_width;
1691 for (int i = 0; i < NumDims - 3; ++i)
1693 pre_contract_dims[0] *= in.dimension(i);
1697 // Molds the output of the contraction into the shape expected by the user
1698 // (assuming this is ColMajor):
1699 // - 1st dim: kernel filters
1700 // - 2nd dim: output height
1701 // - 3rd dim: output width
1702 // - 4th dim and beyond: everything else including batch size
1703 DSizes<TensorIndex, NumDims> post_contract_dims;
1706 post_contract_dims[0] = kernelFilters;
1707 post_contract_dims[1] = out_height;
1708 post_contract_dims[2] = out_width;
1709 for (int i = 3; i < NumDims; ++i)
1711 post_contract_dims[i] = in.dimension(i);
1716 post_contract_dims[NumDims - 1] = kernelFilters;
1717 post_contract_dims[NumDims - 2] = out_height;
1718 post_contract_dims[NumDims - 3] = out_width;
1719 for (int i = 0; i < NumDims - 3; ++i)
1721 post_contract_dims[i] = in.dimension(i);
1725 DSizes<TensorIndex, 2> kernel_dims;
1728 kernel_dims[0] = kernelFilters;
1729 kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1733 kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1734 kernel_dims[1] = kernelFilters;
1736 if (padding_explicit)
1739 Cond<internal::traits<Input>::Layout == ColMajor>(),
1740 kernel.reshape(kernel_dims)
1742 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1743 row_in_stride, col_in_stride,
1744 /*row_inflate_stride=*/1,
1745 /*col_inflate_stride=*/1, padding_top,
1746 padding_bottom, padding_left, padding_right,
1747 /*padding_value=*/0)
1748 .reshape(pre_contract_dims),
1749 contract_dims, output_kernel)
1750 .reshape(post_contract_dims),
1752 .extract_image_patches(
1753 kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride,
1754 /*row_inflate_stride=*/1,
1755 /*col_inflate_stride=*/1, padding_top, padding_bottom, padding_left, padding_right,
1756 /*padding_value=*/0)
1757 .reshape(pre_contract_dims)
1758 .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1759 .reshape(post_contract_dims));
1764 Cond<internal::traits<Input>::Layout == ColMajor>(),
1765 kernel.reshape(kernel_dims)
1767 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
1768 row_in_stride, col_in_stride, padding_type)
1769 .reshape(pre_contract_dims),
1770 contract_dims, output_kernel)
1771 .reshape(post_contract_dims),
1773 .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
1774 col_in_stride, padding_type)
1775 .reshape(pre_contract_dims)
1776 .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1777 .reshape(post_contract_dims));
1781 } // end namespace Eigen
1783 #endif // __NNFW_CKER_EIGEN_EIGEN_SPATIAL_CONVOLUTIONS_INL_H__