arm_compute v18.05
[platform/upstream/armcl.git] / src / graph / backends / GLES / GCNodeValidator.cpp
1 /*
2  * Copyright (c) 2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph/backends/GLES/GCNodeValidator.h"
25
26 #include "arm_compute/graph/backends/ValidateHelpers.h"
27 #include "arm_compute/graph/nodes/Nodes.h"
28
29 #include "arm_compute/core/utils/misc/Cast.h"
30 #include "arm_compute/runtime/GLES_COMPUTE/GCFunctions.h"
31
32 using namespace arm_compute::utils::cast;
33
34 namespace arm_compute
35 {
36 namespace graph
37 {
38 namespace backends
39 {
40 namespace
41 {
42 /** Validates a Depthwise Convolution layer node
43  *
44  * @param[in] node Node to validate
45  *
46  * @return Status
47  */
48 Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
49 {
50     ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating GCDepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
51     ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
52     ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
53
54     // Extract IO and info
55     arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
56     ARM_COMPUTE_ERROR_ON(weights == nullptr);
57
58     // Validate function
59     ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->tensor_shape().x() != 3 && weights->tensor_shape().y() != 3, "Unsupported depthwise convolution");
60     node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::OPTIMIZED_3x3);
61
62     return Status{};
63 }
64 /** Validates a Convolution layer node
65  *
66  * @param[in] node Node to validate
67  *
68  * @return Status
69  */
70 Status validate_convolution_layer(ConvolutionLayerNode &node)
71 {
72     ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
73     ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
74     ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
75
76     // Extract IO and info
77     arm_compute::ITensorInfo *weights        = detail::get_backing_tensor_info(node.input(1));
78     const PadStrideInfo       conv_info      = node.convolution_info();
79     const ConvolutionMethod   conv_algorithm = node.convolution_method();
80
81     // Validate function
82     if(conv_algorithm == ConvolutionMethod::DIRECT)
83     {
84         bool is_square         = weights->tensor_shape().x() == weights->tensor_shape().y();
85         bool is_direct         = (weights->tensor_shape().x() == 1) || (weights->tensor_shape().x() == 3) || (weights->tensor_shape().x() == 5);
86         bool is_correct_stride = (conv_info.stride().first) <= 2 && (conv_info.stride().second <= 2);
87         if(!(is_square && is_direct && is_correct_stride))
88         {
89             node.set_convolution_method(ConvolutionMethod::DEFAULT);
90         }
91     }
92
93     return Status{};
94 }
95 } // namespace
96
97 Status GCNodeValidator::validate(INode *node)
98 {
99     if(node == nullptr)
100     {
101         return Status{};
102     }
103
104     NodeType type = node->type();
105     switch(type)
106     {
107         case NodeType::ConvolutionLayer:
108             return validate_convolution_layer(*polymorphic_downcast<ConvolutionLayerNode *>(node));
109         case NodeType::DepthwiseConvolutionLayer:
110             return validate_depthwise_convolution_layer(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
111         case NodeType::FlattenLayer:
112             return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation");
113         case NodeType::ReshapeLayer:
114             return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation");
115         default:
116             return Status{};
117     }
118 }
119 } // namespace backends
120 } // namespace graph
121 } // namespace arm_compute