IVGCVSW-4449 Add missing QLstm nullptr checks
[platform/upstream/armnn.git] / src/armnn/layers/QLstmLayer.cpp
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "QLstmLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/LstmParams.hpp>
#include <armnn/TypesUtils.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

namespace armnn
{

QLstmLayer::QLstmLayer(const QLstmDescriptor& param, const char* name)
        : LayerWithParameters(3, 3, LayerType::QLstm, param, name)
{
}
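
// CreateWorkload() packs raw pointers to the layer's constant tensors into a QLstmQueueDescriptor.
// Optional groups (CIFG, projection, peephole, layer normalisation) are only populated when the
// corresponding descriptor flag is set, so members of disabled groups remain nullptr.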
std::unique_ptr<IWorkload> QLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    QLstmQueueDescriptor descriptor;

    // Basic parameters
    descriptor.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights.get();
    descriptor.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights.get();
    descriptor.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights.get();
    descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
    descriptor.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights.get();
    descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
    descriptor.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias.get();
    descriptor.m_CellBias = m_BasicParameters.m_CellBias.get();
    descriptor.m_OutputGateBias = m_BasicParameters.m_OutputGateBias.get();

    // CIFG parameters
    if (!m_Param.m_CifgEnabled)
    {
        descriptor.m_InputToInputWeights     = m_CifgParameters.m_InputToInputWeights.get();
        descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
        descriptor.m_InputGateBias           = m_CifgParameters.m_InputGateBias.get();
    }

    // Projection parameters
    if (m_Param.m_ProjectionEnabled)
    {
        descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
        descriptor.m_ProjectionBias    = m_ProjectionParameters.m_ProjectionBias.get();
    }

    // Peephole parameters
    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
        }

        descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
        descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
    }

    // Layer normalisation parameters
    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
        }
        descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
        descriptor.m_CellLayerNormWeights   = m_LayerNormParameters.m_CellLayerNormWeights.get();
        descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
    }

    return factory.CreateQLstm(descriptor, PrepInfoAndDesc(descriptor));
}
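
// Clone() deep-copies every constant tensor into the cloned layer. Each copy is guarded by a
// nullptr check so that unset optional parameters stay nullptr instead of being dereferenced.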
QLstmLayer* QLstmLayer::Clone(Graph& graph) const
{
    auto layer = CloneBase<QLstmLayer>(graph, m_Param, GetName());

    layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_InputToForgetWeights) : nullptr;
    layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_InputToCellWeights) : nullptr;
    layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_InputToOutputWeights) : nullptr;
    layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_RecurrentToForgetWeights) : nullptr;
    layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_RecurrentToCellWeights) : nullptr;
    layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_RecurrentToOutputWeights) : nullptr;
    layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_ForgetGateBias) : nullptr;
    layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_CellBias) : nullptr;
    layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
            std::make_unique<ScopedCpuTensorHandle>(*m_BasicParameters.m_OutputGateBias) : nullptr;

    if (!m_Param.m_CifgEnabled)
    {
        layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_InputToInputWeights) : nullptr;
        layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_RecurrentToInputWeights) : nullptr;
        layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
                std::make_unique<ScopedCpuTensorHandle>(*m_CifgParameters.m_InputGateBias) : nullptr;
    }

    if (m_Param.m_ProjectionEnabled)
    {
        layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_ProjectionParameters.m_ProjectionWeights) : nullptr;
        layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
                std::make_unique<ScopedCpuTensorHandle>(*m_ProjectionParameters.m_ProjectionBias) : nullptr;
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
                    std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToInputWeights) : nullptr;
        }

        layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToForgetWeights) : nullptr;
        layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToOutputWeights) : nullptr;
    }

    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
                    std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_InputLayerNormWeights) : nullptr;
        }

        layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_ForgetLayerNormWeights) : nullptr;
        layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_CellLayerNormWeights) : nullptr;
        layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
                std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_OutputLayerNormWeights) : nullptr;
    }

    return std::move(layer);
}
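
// The three inputs are input, previousOutputIn and previousCellStateIn; the inferred outputs are
// outputStateOut [batchSize, outputSize], cellStateOut [batchSize, numUnits] and
// output [batchSize, outputSize].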
std::vector<TensorShape> QLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 3);

    // Get input values for validation
    unsigned int batchSize = inputShapes[0][0];
    unsigned int outputSize = inputShapes[1][1];
    unsigned int numUnits = inputShapes[2][1];

    std::vector<TensorShape> outShapes;
    outShapes.push_back(TensorShape({ batchSize, outputSize })); // outputStateOut
    outShapes.push_back(TensorShape({ batchSize, numUnits }));   // cellStateOut
    outShapes.push_back(TensorShape({ batchSize, outputSize })); // output

    return outShapes;
}
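
// Validates the three input connections and the inferred output shapes, and asserts that every
// mandatory constant tensor is set. CIFG-only tensors must be unset when CIFG is enabled, while
// peephole and layer-normalisation tensors are only checked when their feature is enabled.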
void QLstmLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(3, CHECK_LOCATION());

    auto inferredShapes = InferOutputShapes(
    {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousOutputIn
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()  // previousCellStateIn
    });

    ARMNN_ASSERT(inferredShapes.size() == 3);

    // Check if the weights are nullptr for basic params
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
            "QLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
            "QLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
            "QLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
            "QLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
            "QLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
            "QLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
            "QLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
            "QLstmLayer: m_BasicParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
            "QLstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");

    if (!m_Param.m_CifgEnabled)
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
                "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
                "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
                "QLstmLayer: m_CifgParameters.m_InputGateBias should not be null.");

        ConditionalThrowIfNotEqual<LayerValidationException>(
                "QLstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
                GetOutputSlot(0).GetTensorInfo().GetShape(),
                inferredShapes[0]);
    }
    else
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
                "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
                "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should "
                "not have a value when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
                "QLstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");

        ConditionalThrowIfNotEqual<LayerValidationException>(
                "QLstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
                GetOutputSlot(0).GetTensorInfo().GetShape(),
                inferredShapes[0]);
    }

    if (m_Param.m_ProjectionEnabled)
    {
        ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
                         "QLstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
                    "QLstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
                    "when Peephole is enabled and CIFG is disabled.");
        }

        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
                         "QLstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
                         "QLstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
    }

    ConditionalThrowIfNotEqual<LayerValidationException>(
            "QLstmLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
            GetOutputSlot(1).GetTensorInfo().GetShape(),
            inferredShapes[1]);
    ConditionalThrowIfNotEqual<LayerValidationException>(
            "QLstmLayer: TensorShape set on OutputSlot[2] does not match the inferred shape.",
            GetOutputSlot(2).GetTensorInfo().GetShape(),
            inferredShapes[2]);

    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
                             "QLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights should not be null.");
        }
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
                         "QLstmLayer: m_LayerNormParameters.m_ForgetLayerNormWeights should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
                         "QLstmLayer: m_LayerNormParameters.m_CellLayerNormWeights should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
                         "QLstmLayer: m_LayerNormParameters.m_OutputLayerNormWeights should not be null.");
    }
}
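
// Returns references to every constant tensor handle, including optional handles that may be empty.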
Layer::ConstantTensors QLstmLayer::GetConstantTensorsByRef()
{
    return {m_BasicParameters.m_InputToForgetWeights,
            m_BasicParameters.m_InputToCellWeights,
            m_BasicParameters.m_InputToOutputWeights,
            m_BasicParameters.m_RecurrentToForgetWeights,
            m_BasicParameters.m_RecurrentToCellWeights,
            m_BasicParameters.m_RecurrentToOutputWeights,
            m_BasicParameters.m_ForgetGateBias,
            m_BasicParameters.m_CellBias,
            m_BasicParameters.m_OutputGateBias,

            // Cifg parameters
            m_CifgParameters.m_InputToInputWeights,
            m_CifgParameters.m_RecurrentToInputWeights,
            m_CifgParameters.m_InputGateBias,

            // Projection parameters
            m_ProjectionParameters.m_ProjectionWeights,
            m_ProjectionParameters.m_ProjectionBias,

            // Peephole parameters
            m_PeepholeParameters.m_CellToInputWeights,
            m_PeepholeParameters.m_CellToForgetWeights,
            m_PeepholeParameters.m_CellToOutputWeights,

            // Layer normalisation parameters
            m_LayerNormParameters.m_InputLayerNormWeights,
            m_LayerNormParameters.m_ForgetLayerNormWeights,
            m_LayerNormParameters.m_CellLayerNormWeights,
            m_LayerNormParameters.m_OutputLayerNormWeights};
}
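
// Accept() exposes the layer's constant tensors to an ILayerVisitor: each handle that is not
// nullptr is mapped into a ConstTensor copy and referenced from the LstmInputParams passed to
// VisitQLstmLayer().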
void QLstmLayer::Accept(ILayerVisitor& visitor) const
{
    LstmInputParams inputParams;

    ConstTensor inputToInputWeightsTensor;
    if (m_CifgParameters.m_InputToInputWeights != nullptr)
    {
        ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
                                                  m_CifgParameters.m_InputToInputWeights->Map(true));
        inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
        inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
    }

    ConstTensor inputToForgetWeightsTensor;
    if (m_BasicParameters.m_InputToForgetWeights != nullptr)
    {
        ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
                                                   m_BasicParameters.m_InputToForgetWeights->Map(true));
        inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
        inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    }

    ConstTensor inputToCellWeightsTensor;
    if (m_BasicParameters.m_InputToCellWeights != nullptr)
    {
        ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
                                                 m_BasicParameters.m_InputToCellWeights->Map(true));
        inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
        inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
    }

    ConstTensor inputToOutputWeightsTensor;
    if (m_BasicParameters.m_InputToOutputWeights != nullptr)
    {
        ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
                                                   m_BasicParameters.m_InputToOutputWeights->Map(true));
        inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
        inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
    }

    ConstTensor recurrentToInputWeightsTensor;
    if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
    {
        ConstTensor recurrentToInputWeightsTensorCopy(
                m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
                m_CifgParameters.m_RecurrentToInputWeights->Map(true));
        recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
        inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
    }

    ConstTensor recurrentToForgetWeightsTensor;
    if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
    {
        ConstTensor recurrentToForgetWeightsTensorCopy(
                m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
                m_BasicParameters.m_RecurrentToForgetWeights->Map(true));
        recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
        inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    }

    ConstTensor recurrentToCellWeightsTensor;
    if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
    {
        ConstTensor recurrentToCellWeightsTensorCopy(
                m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
                m_BasicParameters.m_RecurrentToCellWeights->Map(true));
        recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
        inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    }

    ConstTensor recurrentToOutputWeightsTensor;
    if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
    {
        ConstTensor recurrentToOutputWeightsTensorCopy(
                m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
                m_BasicParameters.m_RecurrentToOutputWeights->Map(true));
        recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
        inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
    }

    ConstTensor cellToInputWeightsTensor;
    if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
    {
        ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
                                                 m_PeepholeParameters.m_CellToInputWeights->Map(true));
        cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
        inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
    }

    ConstTensor cellToForgetWeightsTensor;
    if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
    {
        ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
                                                  m_PeepholeParameters.m_CellToForgetWeights->Map(true));
        cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy;
        inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
    }

    ConstTensor cellToOutputWeightsTensor;
    if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
    {
        ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
                                                  m_PeepholeParameters.m_CellToOutputWeights->Map(true));
        cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy;
        inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
    }

    ConstTensor inputGateBiasTensor;
    if (m_CifgParameters.m_InputGateBias != nullptr)
    {
        ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
                                            m_CifgParameters.m_InputGateBias->Map(true));
        inputGateBiasTensor = inputGateBiasTensorCopy;
        inputParams.m_InputGateBias = &inputGateBiasTensor;
    }

    ConstTensor forgetGateBiasTensor;
    if (m_BasicParameters.m_ForgetGateBias != nullptr)
    {
        ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
                                             m_BasicParameters.m_ForgetGateBias->Map(true));
        forgetGateBiasTensor = forgetGateBiasTensorCopy;
        inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
    }

    ConstTensor cellBiasTensor;
    if (m_BasicParameters.m_CellBias != nullptr)
    {
        ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(),
                                       m_BasicParameters.m_CellBias->Map(true));
        cellBiasTensor = cellBiasTensorCopy;
        inputParams.m_CellBias = &cellBiasTensor;
    }

    ConstTensor outputGateBias;
    if (m_BasicParameters.m_OutputGateBias != nullptr)
    {
        ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
                                       m_BasicParameters.m_OutputGateBias->Map(true));
        outputGateBias = outputGateBiasCopy;
        inputParams.m_OutputGateBias = &outputGateBias;
    }

    ConstTensor projectionWeightsTensor;
    if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
    {
        ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
                                                m_ProjectionParameters.m_ProjectionWeights->Map(true));
        projectionWeightsTensor = projectionWeightsTensorCopy;
        inputParams.m_ProjectionWeights = &projectionWeightsTensor;
    }

    ConstTensor projectionBiasTensor;
    if (m_ProjectionParameters.m_ProjectionBias != nullptr)
    {
        ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
                                             m_ProjectionParameters.m_ProjectionBias->Map(true));
        projectionBiasTensor = projectionBiasTensorCopy;
        inputParams.m_ProjectionBias = &projectionBiasTensor;
    }

    ConstTensor inputLayerNormTensor;
    if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
    {
        ConstTensor inputLayerNormTensorCopy(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(),
                                             m_LayerNormParameters.m_InputLayerNormWeights->Map(true));
        inputLayerNormTensor = inputLayerNormTensorCopy;
        inputParams.m_InputLayerNormWeights = &inputLayerNormTensor;
    }

    ConstTensor forgetLayerNormTensor;
    if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
    {
        ConstTensor forgetLayerNormTensorCopy(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(),
                                              m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true));
        forgetLayerNormTensor = forgetLayerNormTensorCopy;
        inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor;
    }

    ConstTensor cellLayerNormTensor;
    if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
    {
        ConstTensor cellLayerNormTensorCopy(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(),
                                            m_LayerNormParameters.m_CellLayerNormWeights->Map(true));
        cellLayerNormTensor = cellLayerNormTensorCopy;
        inputParams.m_CellLayerNormWeights = &cellLayerNormTensor;
    }

    ConstTensor outputLayerNormTensor;
    if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
    {
        ConstTensor outputLayerNormTensorCopy(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(),
                                              m_LayerNormParameters.m_OutputLayerNormWeights->Map(true));
        outputLayerNormTensor = outputLayerNormTensorCopy;
        inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor;
    }

    visitor.VisitQLstmLayer(this, GetParameters(), inputParams, GetName());
}

} // namespace armnn