/**
 * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *   http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/**
 * @date   04 December 2019
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Jijoong Moon <jijoong.moon@samsung.com>
 * @bug    No known bugs except for NYI items
 * @brief  This is a DeepQ Reinforcement Learning example.
 *         Environment: CartPole-v0 (native, or from OpenAI Gym)
 *         Experience replay is supported to remove correlation in the data.
 *         To maintain stability, two neural networks are used
 *         (mainNN, targetNN).
 *
 *            +-----------------------+             +----------+
 *            |     Initialization    |------------>|          |
 *            +-----------------------+             |          |
 *                        |                         |          |
 *                        v                         |          |
 *  +-------->+-----------------------+             |          |
 *  |    +--->| Get Action from Q Net |             |          |
 *  |    |    +-----------------------+             |          |
 *  |    |                |                         |          |
 *  |    |                v                         |   Env    |
 *  |    |    +-----------------------+             |          |
 *  |    |    |       Put Action      |------------>|          |
 *  |    |    +-----------------------+             |          |
 *  |    |                |                         |          |
 *  |    |                v                         |          |
 *  |    |    +-----------------------+             |          |
 *  |    |    |       Get State       |<------------|          |
 *  |    |    +-----------------------+             |          |
 *  |    |                |                         |          |
 *  |    |                v                         |          |
 *  |    |    +-----------------------+             |          |
 *  |    |    | Set Penalty & Update Q|             |          |
 *  |    |    |  from Target Network  |             |          |
 *  |    |    +-----------------------+             |          |
 *  |    |                |                         |          |
 *  |    |                v                         |          |
 *  |    |    +-----------------------+             |          |
 *  |    +----| Put Experience Buffer |             |          |
 *  |         +-----------------------+             |          |
 *  |                     |                         |          |
 *  |                     v                         |          |
 *  |         +-----------------------+             |          |
 *  |         |   Training Q Network  |             |          |
 *  |         |    with batch_size    |             |          |
 *  |         +-----------------------+             |          |
 *  |                     |                         |          |
 *  |                     v                         |          |
 *  |         +-----------------------+             |          |
 *  |         |    copy main Net to   |             |          |
 *  +---------|       Target Net      |             |          |
 *            +-----------------------+             +----------+
 */
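/**
 * Note (added for clarity, not part of the original source): the training
 * step above implements the standard DQN regression target
 *   y = reward                                        (terminal transition)
 *   y = reward + DISCOUNT * max_a' Q_target(s', a')   (otherwise)
 * mainNN is trained toward y while targetNN stays frozen between syncs.
 */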
#include <deque>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>

#include "neuralnet.h"

#if defined(USE_GYM)
#include "include/gym/gym.h"
#define STATE Gym::State
#define ENV Gym::Environment
#define PTR boost::shared_ptr<ENV>
#else
#include "CartPole/cartpole.h"
#define STATE Env::State
#define ENV Env::CartPole
#define PTR std::shared_ptr<ENV>
#endif
/**
 * @brief Maximum episodes to run
 */
#define MAX_EPISODES 300

/**
 * @brief boolean to render (only works for OpenAI/Gym)
 */
#define RENDER true

/**
 * @brief Max number of data in the Replay Queue
 */
#define REPLAY_MEMORY 50000

/**
 * @brief minibatch size
 */
#define BATCH_SIZE 30

/**
 * @brief discount factor
 */
#define DISCOUNT 0.9

/**
 * @brief if true : update, else : forward propagation only
 */
#define TRAINING true
/**
 * @brief Experience data type to store in the experience buffer
 */
typedef struct {
  STATE state;
  std::vector<float> action;
  float reward;
  STATE next_state;
  bool done;
} Experience;

/**
 * @brief seed for rand_r()
 */
unsigned int seed;
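/**
 * Usage sketch (illustration only): one CartPole transition stored in the
 * replay queue looks like
 *
 *   Experience ex;
 *   ex.state = s;              // 4 floats: cart pos/vel, pole angle/vel
 *   ex.action = action;        // discrete action index encoded as a float
 *   ex.reward = next_s.reward;
 *   ex.next_state = next_s;
 *   ex.done = next_s.done;
 *   expQ.push_back(ex);
 */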
/**
 * @brief Generate a random float value between Min and Max
 * @param[in] Min minimum value
 * @param[in] Max maximum value
 * @retval random value in [Min, Max)
 */
static float RandomFloat(float Min, float Max) {
  float r = Min + static_cast<float>(rand_r(&seed)) /
                    (static_cast<float>(RAND_MAX) / (Max - Min));
  return r;
}
/**
 * @brief Generate a random integer value between Min and Max
 * @param[in] Min minimum value
 * @param[in] Max maximum value
 * @retval random value in [Min, Max]
 */
static int rangeRandom(int Min, int Max) {
  int n = Max - Min + 1;
  int remainder = RAND_MAX % n;
  int x;
  do {
    x = rand_r(&seed);
  } while (x >= RAND_MAX - remainder);
  return Min + x % n;
}
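/**
 * Hedged alternative (illustration only, not used by this example): the two
 * helpers above can also be written with <random>, which removes the shared
 * rand_r() seed and the manual modulo-bias rejection loop.
 */
#include <random>

static float RandomFloatAlt(float Min, float Max) {
  static std::mt19937 gen(std::random_device{}());
  std::uniform_real_distribution<float> dist(Min, Max);
  return dist(gen); /* uniform in [Min, Max) */
}

static int rangeRandomAlt(int Min, int Max) {
  static std::mt19937 gen(std::random_device{}());
  std::uniform_int_distribution<int> dist(Min, Max);
  return dist(gen); /* uniform in [Min, Max] without modulo bias */
}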
/**
 * @brief Generate a randomly selected Experience batch from the
 *        Experience Replay Queue; the batch size equals BATCH_SIZE
 *        (or the queue size, while the queue is still smaller)
 * @param[in] Q Experience Replay Queue
 * @retval Experience vector
 */
static std::vector<Experience> getBatchSizeData(std::deque<Experience> Q) {
  int Max = (BATCH_SIZE > Q.size()) ? BATCH_SIZE : Q.size();
  int Min = (BATCH_SIZE < Q.size()) ? BATCH_SIZE : Q.size();
  int count = 0;

  std::vector<bool> duplicate;
  std::vector<int> mem;
  std::vector<Experience> in_Exp;

  /* mark every index as not-yet-sampled */
  duplicate.resize(Max);
  for (int i = 0; i < Max; i++)
    duplicate[i] = false;

  /* rejection-sample Min distinct indices into mem */
  while (count < Min) {
    int nomi = rangeRandom(0, Q.size() - 1);
    if (!duplicate[nomi]) {
      mem.push_back(nomi);
      duplicate[nomi] = true;
      count++;
    }
  }

  /* gather the selected experiences */
  for (int i = 0; i < Min; i++) {
    in_Exp.push_back(Q[mem[i]]);
  }

  return in_Exp;
}
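/**
 * Hedged alternative (illustration only, assumes C++17): std::sample draws
 * min(BATCH_SIZE, Q.size()) experiences without replacement in a single
 * call, replacing the rejection loop above.
 */
#include <algorithm>
#include <iterator>

static std::vector<Experience>
getBatchSizeDataAlt(const std::deque<Experience> &Q) {
  std::vector<Experience> batch;
  static std::mt19937 gen(std::random_device{}());
  std::sample(Q.begin(), Q.end(), std::back_inserter(batch),
              std::min(static_cast<size_t>(BATCH_SIZE), Q.size()), gen);
  return batch;
}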
/**
 * @brief Calculate argmax
 * @param[in] vec input to calculate argmax
 * @retval index of the largest element
 */
static int argmax(std::vector<float> vec) {
  int ret = 0;
  float val = vec[0];
  for (unsigned int i = 0; i < vec.size(); i++) {
    if (val < vec[i]) {
      val = vec[i];
      ret = i;
    }
  }
  return ret;
}
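/**
 * Hedged one-line equivalent (illustration only): std::max_element returns
 * an iterator to the largest element; its distance from begin() is the same
 * argmax index the loop above computes.
 */
static int argmaxAlt(const std::vector<float> &vec) {
  return static_cast<int>(
    std::distance(vec.begin(), std::max_element(vec.begin(), vec.end())));
}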
/**
 * @brief Create & initialize environment
 * @param[in] input_size State size: 4 for CartPole-v0
 * @param[in] output_size Action size: 2 for CartPole-v0
 * @retval Env object pointer
 */
static PTR init_environment(int &input_size, int &output_size) {
#if defined(USE_GYM)
  boost::shared_ptr<Gym::Client> client;
  std::string env_id = "CartPole-v0";
  try {
    client = Gym::client_create("127.0.0.1", 5000);
  } catch (const std::exception &e) {
    fprintf(stderr, "ERROR: %s\n", e.what());
    return nullptr;
  }

  boost::shared_ptr<ENV> env = client->make(env_id);
  boost::shared_ptr<Gym::Space> action_space = env->action_space();
  boost::shared_ptr<Gym::Space> observation_space = env->observation_space();

  input_size = observation_space->sample().size();
  output_size = action_space->discreet_n;
#else
  std::shared_ptr<ENV> env(new ENV);
  env->init();

  input_size = env->getInputSize();
  output_size = env->getOutputSize();
#endif

  return env;
}
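/**
 * Usage sketch (illustration only): both build variants expose the same
 * surface, e.g.
 *
 *   int in_dim, out_dim;
 *   PTR env = init_environment(in_dim, out_dim); // CartPole-v0: 4 and 2
 */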
/**
 * @brief Run DeepQ training
 * @param[in] arg 1 : configuration file path
 */
int main(int argc, char **argv) {
  if (argc < 2) {
    std::cout << "./DeepQ Config.ini\n";
    exit(1);
  }

  const std::string weight_file = "model_deepq.bin";
  const std::vector<std::string> args(argv + 1, argv + argc);
  std::string config = args[0];

  std::string filepath = "debug.txt";
  std::ofstream writeFile(filepath.data());

  if (!writeFile.is_open()) {
    std::cout << "Error opening file" << std::endl;
    return 1;
  }
  std::deque<Experience> expQ;
  PTR env;

  /**
   * @brief Initialize Environment
   */
  int input_size, output_size;
  env = init_environment(input_size, output_size);
  printf("input_size %d, output_size %d\n", input_size, output_size);
  /**
   * @brief Create mainNet & Target Net
   */
  nntrainer::NeuralNetwork mainNet;
  nntrainer::NeuralNetwork targetNet;

  try {
    mainNet.loadFromConfig(config);
    mainNet.compile();
    mainNet.initialize();
    targetNet.loadFromConfig(config);
    targetNet.compile();
    targetNet.initialize();
  } catch (...) {
    std::cerr << "Error during init" << std::endl;
    return 1;
  }

  /**
   * @brief Read model data if any exists
   */
  try {
    mainNet.load(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
    targetNet.load(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
  } catch (...) {
    std::cerr << "Error during readBin\n";
    return 1;
  }
  for (int episode = 0; episode < MAX_EPISODES; episode++) {
    float epsilon = 1. / ((episode / 10) + 1);
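    /**
     * Worked example (added for clarity): integer division makes epsilon
     * decay stepwise every 10 episodes: episodes 0-9 -> 1.0, 10-19 -> 0.5,
     * 20-29 -> 1/3, ..., 290-299 -> 1/30, so exploration shrinks but never
     * reaches zero within MAX_EPISODES.
     */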
    bool done = false;
    int step_count = 0;
    STATE s;
    STATE next_s;
    env->reset(&s);

    /**
     * @brief Do until the end of episode
     */
    while (!done) {
      std::vector<float> action;
      float r = RandomFloat(0.0, 1.0);

      if (r < epsilon && TRAINING) {
        /* epsilon-greedy exploration: take a random action */
#if defined(USE_GYM)
        boost::shared_ptr<Gym::Space> action_space = env->action_space();
        action = action_space->sample();
#else
        action = env->sample();
#endif
        std::cout << "test result random action : " << action[0] << "\n";
      } else {
        std::vector<float> input(s.observation.begin(), s.observation.end());
        /**
         * @brief get action for the input state with mainNet
         */
        nntrainer::Tensor in_tensor;
        nntrainer::sharedConstTensor test;
        try {
          in_tensor = nntrainer::Tensor(
            {input}, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16});
        } catch (...) {
          std::cerr << "Error while constructing tensor" << std::endl;
          return 1;
        }
        try {
          test = mainNet.forwarding({MAKE_SHARED_TENSOR(in_tensor)})[0];
        } catch (...) {
          std::cerr << "Error while forwarding the network" << std::endl;
          return 1;
        }
        const float *data = test->getData();
        unsigned int len = test->getDim().getDataLen();
        std::vector<float> temp(data, data + len);
        action.push_back(argmax(temp));

        std::cout << "qvalues : [" << temp[0] << "][" << temp[1]
                  << "] : ACTION (argmax) = " << argmax(temp) << "\n";
      }
      /**
       * @brief step Env with this action & save next State in next_s
       */
      env->step(action, RENDER, &next_s);

      Experience ex;
      ex.state = s;
      ex.action = action;
      ex.reward = next_s.reward;
      ex.next_state = next_s;
      ex.done = next_s.done;
      if (expQ.size() > REPLAY_MEMORY) {
        expQ.pop_front();
      }
      done = next_s.done;

      /**
       * @brief Set Penalty or reward
       */
      if (done) {
        ex.reward = -100.0; /* penalty for ending the episode */
        std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DONE : Episode "
                  << episode << " Iteration : " << step_count << "\n";
      }
      /**
       * @brief Save the experience in the Experience Replay Buffer
       */
      expQ.push_back(ex);

      s = next_s;
      step_count++;
      if (step_count > 10000) {
        std::cout << "step_count is over 10000\n";
        break;
      }
    }

    if (step_count > 10000)
      break;

    if (!TRAINING && done)
      continue;
    /**
     * @brief Train every 10 episodes
     */
    if (episode % 10 == 1 && TRAINING) {
      for (int iter = 0; iter < 50; iter++) {
        /**
         * @brief Get a batch of experiences
         */
        std::vector<Experience> in_Exp = getBatchSizeData(expQ);
        std::vector<std::vector<std::vector<std::vector<float>>>> inbatch;
        std::vector<std::vector<std::vector<std::vector<float>>>> next_inbatch;
        /**
         * @brief Generate label with the next state
         */
        for (unsigned int i = 0; i < in_Exp.size(); i++) {
          STATE state = in_Exp[i].state;
          STATE next_state = in_Exp[i].next_state;
          std::vector<float> in(state.observation.begin(),
                                state.observation.end());
          inbatch.push_back({{in}});

          std::vector<float> next_in(next_state.observation.begin(),
                                     next_state.observation.end());
          next_inbatch.push_back({{next_in}});
        }
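        /**
         * Shape note (added for clarity): after this loop, inbatch and
         * next_inbatch hold in_Exp.size() entries of shape
         * [1][1][input_size], which the NCHW tensor constructor below reads
         * as batch x channel(1) x height(1) x width(4) for CartPole-v0.
         */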
        nntrainer::Tensor q_in, nq_in;
        try {
          q_in = nntrainer::Tensor(
            inbatch, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16});
          nq_in = nntrainer::Tensor(next_inbatch, {nntrainer::Tformat::NCHW,
                                                   nntrainer::Tdatatype::FP16});
        } catch (...) {
          std::cerr << "Error during tensor construction" << std::endl;
          return 1;
        }
        /**
         * @brief Run forward propagation with mainNet
         */
        nntrainer::sharedConstTensor Q;
        try {
          Q = mainNet.forwarding({MAKE_SHARED_TENSOR(q_in)})[0];
        } catch (...) {
          std::cerr << "Error during forwarding main network" << std::endl;
          return 1;
        }
        /**
         * @brief Run forward propagation with targetNet
         */
        nntrainer::sharedConstTensor NQ;
        try {
          NQ = targetNet.forwarding({MAKE_SHARED_TENSOR(nq_in)})[0];
        } catch (...) {
          std::cerr << "Error during forwarding target network" << std::endl;
          return 1;
        }

        const float *nqa = NQ->getData();
        /**
         * @brief Update Q values & update mainNetwork
         */
        nntrainer::Tensor tempQ = *Q;
        for (unsigned int i = 0; i < in_Exp.size(); i++) {
          if (in_Exp[i].done) {
            /* terminal transition: the label is the raw reward */
            try {
              tempQ.setValue(i, 0, 0, (int)in_Exp[i].action[0],
                             (float)in_Exp[i].reward);
            } catch (...) {
              std::cerr << "Error while setting value" << std::endl;
              return 1;
            }
          } else {
            /* non-terminal: label = reward + DISCOUNT * max_a' Q_target */
            float next = (nqa[i * NQ->width()] > nqa[i * NQ->width() + 1])
                           ? nqa[i * NQ->width()]
                           : nqa[i * NQ->width() + 1];
            try {
              tempQ.setValue(i, 0, 0, (int)in_Exp[i].action[0],
                             (float)in_Exp[i].reward + DISCOUNT * next);
            } catch (...) {
              std::cerr << "Error while setting value" << std::endl;
              return 1;
            }
          }
        }
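        /**
         * Worked example (added for clarity): with DISCOUNT = 0.9, a
         * non-terminal transition with reward 1.0 whose next state scores
         * [0.5, 2.0] under targetNet gets the label
         *   y = 1.0 + 0.9 * max(0.5, 2.0) = 2.8
         * at the taken action's index, while a terminal transition keeps
         * y = reward because there is no successor state.
         */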
        nntrainer::Tensor in_tensor;
        try {
          in_tensor = nntrainer::Tensor(
            inbatch, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16});
          /* train mainNet toward the updated labels in tempQ */
          mainNet.forwarding({MAKE_SHARED_TENSOR(in_tensor)},
                             {MAKE_SHARED_TENSOR(tempQ)});
          mainNet.backwarding(iter);
        } catch (...) {
          std::cerr << "Error during backwarding the network" << std::endl;
          return 1;
        }
      }
      try {
        writeFile << "mainNet Loss : " << mainNet.getLoss()
                  << " : targetNet Loss : " << targetNet.getLoss() << "\n";
        std::cout << "\n\n =================== TRAINING & COPY NET "
                     "==================\n\n";
        std::cout << "mainNet Loss : " << mainNet.getLoss()
                  << "\n targetNet Loss : " << targetNet.getLoss() << "\n\n";
      } catch (std::exception &e) {
        std::cerr << "Error during getLoss: " << e.what() << "\n";
        return 1;
      }
      /**
       * @brief Copy mainNet to targetNet via the weight file
       */
      try {
        mainNet.save(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
        targetNet.load(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
      } catch (std::exception &e) {
        std::cerr << "Error during saveBin: " << e.what() << "\n";
        return 1;
      }
    }
  }

  writeFile.close();

  return 0;
}