Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / cldnn_engine / simple_math.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "simple_math.h"
6 #include <cctype>
7 #include <string>
8 #include <set>
9 #include <stack>
10 #include <map>
11 #include <stdexcept>
12
13 // Using the algorithm from: https://en.wikipedia.org/wiki/Shunting-yard_algorithm
14
15 const std::set<char> SimpleMathExpression::whitespaces = {
16     ' ',
17     '\t',
18 };
19 const std::map<char, SimpleMathExpression::Operator> SimpleMathExpression::operators = {
20     { '+', { 0, std::plus<int>() } },
21     { '-', { 0, std::minus<int>() } },
22     { '*', { 1, std::multiplies<int>() } },
23     { '/', { 1, std::divides<int>()  } },
24     { '%', { 1, std::modulus<int>()  } },
25 };
26
27 void SimpleMathExpression::SetVariables(const std::map<char, int>& vars) {
28     m_variables = vars;
29 }
30
31 bool SimpleMathExpression::SetExpression(const std::string & expression) {
32     m_expression = expression;
33     m_parsed = Parse();
34     return m_parsed;
35 }
36
37 int SimpleMathExpression::Evaluate() const {
38     if (!m_parsed) {
39         throw std::runtime_error("Evaluation error: not parsed yet");
40     }
41
42     std::stack<int> values;
43     for (Token t : m_parsedTokens) {
44         switch (t.type) {
45         case Token::Value:
46             values.push(t.value);
47             break;
48         case Token::Operator: {
49             if (values.size() < 2) {
50                 throw std::runtime_error("Illegal expression: not enough values for operator evaluation");
51             }
52             // pop last 2 values and apply operand
53             int val2 = values.top();
54             values.pop();
55             int val1 = values.top();
56             values.pop();
57             values.push(operators.at(t.op).second(val1, val2));
58         }
59             break;
60         default:
61             throw std::runtime_error("Illegal expression: unhandled token");
62         }
63     }
64     if (values.size() != 1) {
65         throw std::runtime_error("Illegal expression: not enough operators");
66     }
67     return values.top();
68 }
69
70 bool SimpleMathExpression::Parse() {
71     std::stack<char> operatorStack;
72     m_parsedTokens.clear();
73
74     // while there are tokens to be read:
75     for (size_t i = 0; i != m_expression.length(); i++) {
76         //  read a token.
77         while (whitespaces.find(m_expression.at(i)) != whitespaces.end()) i++;  // ignore whitespaces
78         char curr = m_expression.at(i);
79
80         //  if the token is a number, then push it to the output queue.
81         if (isdigit(curr)) {
82             size_t len;
83             int value = std::stoi(m_expression.substr(i), &len);
84             m_parsedTokens.push_back(Token(Token::Value, value, 0));
85             i += (len - 1);
86             continue;
87         }
88
89         //  if the token is a variable, then push it's value to the output queue.
90         if (m_variables.find(curr) != m_variables.end()) {
91             m_parsedTokens.push_back(Token(Token::Value, m_variables.at(curr), 0));
92             continue;
93         }
94
95         //  if the token is an operator, then:
96         if (operators.find(curr) != operators.end()) {
97         //    while there is an operator at the top of the operator stack with
98         //      greater than or equal to precedence:
99         //        pop operators from the operator stack, onto the output queue;
100             while ( !operatorStack.empty() &&
101                     (operators.find(operatorStack.top()) != operators.end()) &&
102                     (operators.at(operatorStack.top()).first >= operators.at(curr).first)) {
103                 char op = operatorStack.top();
104                 operatorStack.pop();
105                 m_parsedTokens.push_back(Token(Token::Operator, 0, op));
106             }
107         //      push the read operator onto the operator stack.
108             operatorStack.push(curr);
109             continue;
110         }
111
112         //  if the token is a left bracket (i.e. "("), then:
113         //    push it onto the operator stack.
114         if (curr == '(') {
115             operatorStack.push(curr);
116             continue;
117         }
118
119         //  if the token is a right bracket (i.e. ")"), then:
120         if (curr == ')') {
121             //    while the operator at the top of the operator stack is not a left bracket:
122             //      pop operators from the operator stack onto the output queue.
123             while (!operatorStack.empty() && operatorStack.top() != '(') {
124                 m_parsedTokens.push_back(Token(Token::Operator, 0, operatorStack.top()));
125                 operatorStack.pop();
126             }
127             //    pop the left bracket from the stack.
128             //    /* if the stack runs out without finding a left bracket, then there are
129             //       mismatched parentheses. */
130             if (!operatorStack.empty() && operatorStack.top() == '(') {
131                 operatorStack.pop();
132             } else {
133                 return false;
134             }
135             continue;
136         }
137
138         // unknown token
139         return false;
140     }
141     // if there are no more tokens to read:
142     //  while there are still operator tokens on the stack:
143     //    /* if the operator token on the top of the stack is a bracket, then
144     //      there are mismatched parentheses. */
145     //    pop the operator onto the output queue.
146     while (!operatorStack.empty()) {
147         if (operatorStack.top() == '(') {
148             return false;
149         }
150         m_parsedTokens.push_back(Token(Token::Operator, 0, operatorStack.top()));
151         operatorStack.pop();
152     }
153
154     // exit.
155     m_parsed = true;
156     return true;
157 }