f16f13836186c4d3cea9826d7b092423a1b1820a
[platform/core/ml/nntrainer.git] / Applications / ReinforcementLearning / DeepQ / jni / main.cpp
1 /**
2  * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *   http://www.apache.org/licenses/LICENSE-2.0
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  *
15  * @file        main.cpp
16  * @date        04 December 2019
17  * @see         https://github.com/nnstreamer/nntrainer
18  * @author      Jijoong Moon <jijoong.moon@samsung.com>
19  * @bug         No known bugs except for NYI items
20  * @brief       This is DeepQ Reinforcement Learning Example
21  *              Environment : CartPole-v0 ( from Native or Open AI / Gym )
22  *              Support Experience Replay to remove data co-relation
23  *              To maintain stability, two Neural Net are used ( mainNN,
24  * targetNN )
25  *
26  *
27  *                  +---------------------+              +----------+
28  *                  |    Initialization   |------------->|          |
29  *                  +---------------------+              |          |
30  *                             |                         |          |
31  *        +------->+-----------------------+             |          |
32  *        |   +--->| Get Action from Q Net |             |          |
33  *        |   |    +-----------------------+             |          |
34  *        |   |                |                         |          |
35  *        |   |     +---------------------+              |   Env    |
36  *        |   |     |      Put Action     |------------->|          |
37  *        |   |     +---------------------+              |          |
38  *        |   |                |                         |          |
39  *        |   |     +---------------------+              |          |
40  *        |   |     |      Get State      |<-------------|          |
41  *        |   |     +---------------------+              |          |
42  *        |   |                |                         |          |
43  *        |   |    +------------------------+            |          |
44  *        |   |    | Set Penalty & Updaet Q |            |          |
45  *        |   |    | from Target Network    |            |          |
46  *        |   |    +------------------------+            |          |
47  *        |   |                |                         |          |
48  *        |   |    +-----------------------+             |          |
49  *        |   +----| Put Experience Buffer |             |          |
50  *        |        +-----------------------+             |          |
51  *        |                    |                         |          |
52  *        |        +------------------------+            |          |
53  *        |        |  Training Q Network    |            |          |
54  *        |        |     with batch_size     |            |          |
55  *        |        +------------------------+            |          |
56  *        |                    |                         |          |
57  *        |        +------------------------+            |          |
58  *        |        |    copy main Net to    |            |          |
59  *        +--------|     Target Net         |            |          |
60  *                 +------------------------+            +----------+
61  *
62  */
63
64 #include "neuralnet.h"
65 #include "tensor.h"
66 #include <fstream>
67 #include <iostream>
68 #include <iterator>
69 #include <memory>
70 #include <queue>
71 #include <stdio.h>
72 #include <unistd.h>
73
74 #ifdef USE_GYM
75 #include "include/gym/gym.h"
76 #define STATE Gym::State
77 #define ENV Gym::Environment
78 #define PTR boost::shared_ptr<ENV>
79 #else
80 #include "CartPole/cartpole.h"
81 #define STATE Env::State
82 #define ENV Env::CartPole
83 #define PTR std::shared_ptr<ENV>
84 #endif
85
86 /**
87  * @brief     Maximum episodes to run
88  */
89 #define MAX_EPISODES 300
90
91 /**
92  * @brief     boolean to reder (only works for openAI/Gym)
93  */
94 #define RENDER true
95
96 /**
97  * @brief     Max Number of data in Replay Queue
98  */
99 #define REPLAY_MEMORY 50000
100
101 /**
102  * @brief     minibach size
103  */
104 #define BATCH_SIZE 30
105
106 /**
107  * @brief     discount factor
108  */
109 #define DISCOUNT 0.9
110
111 /**
112  * @brief     if true : update else : forward propagation
113  */
114 #define TRAINING true
115
116 /**
117  * @brief     Experience data Type to store experience buffer
118  */
119 typedef struct {
120   STATE state;
121   std::vector<float> action;
122   float reward;
123   STATE next_state;
124   bool done;
125 } Experience;
126
127 unsigned int seed;
128
129 /**
130  * @brief     Generate Random double value between min to max
131  * @param[in] min : minimum value
132  * @param[in] max : maximum value
133  * @retval    min < random value < max
134  */
135 static float RandomFloat(float Min, float Max) {
136   float r = Min + static_cast<float>(rand_r(&seed)) /
137                     (static_cast<float>(RAND_MAX) / (Max - Min));
138   return r;
139 }
140
141 /**
142  * @brief     Generate Random integer value between min to max
143  * @param[in] min : minimum value
144  * @param[in] max : maximum value
145  * @retval    min < random value < max
146  */
147 static int rangeRandom(int Min, int Max) {
148   int n = Max - Min + 1;
149   int remainder = RAND_MAX % n;
150   int x;
151   do {
152     x = rand_r(&seed);
153   } while (x >= RAND_MAX - remainder);
154   return Min + x % n;
155 }
156
157 /**
158  * @brief     Generate randomly selected Experience buffer from
159  *            Experience Replay Queue which number is equal batch_size
160  * @param[in] Q Experience Replay Queue
161  * @retval    Experience vector
162  */
163 static std::vector<Experience> getBatchSizeData(std::deque<Experience> Q) {
164   int Max = (BATCH_SIZE > Q.size()) ? BATCH_SIZE : Q.size();
165   int Min = (BATCH_SIZE < Q.size()) ? BATCH_SIZE : Q.size();
166
167   std::vector<bool> duplicate;
168   std::vector<int> mem;
169   std::vector<Experience> in_Exp;
170   int count = 0;
171
172   duplicate.resize(Max);
173
174   for (int i = 0; i < Max; i++)
175     duplicate[i] = false;
176
177   while (count < Min) {
178     int nomi = rangeRandom(0, Q.size() - 1);
179     if (!duplicate[nomi]) {
180       mem.push_back(nomi);
181       duplicate[nomi] = true;
182       count++;
183     }
184   }
185
186   for (int i = 0; i < Min; i++) {
187     in_Exp.push_back(Q[mem[i]]);
188   }
189
190   return in_Exp;
191 }
192
193 /**
194  * @brief     Calculate argmax
195  * @param[in] vec input to calculate argmax
196  * @retval argmax
197  */
198 static int argmax(std::vector<float> vec) {
199   int ret = 0;
200   float val = 0.0;
201   for (unsigned int i = 0; i < vec.size(); i++) {
202     if (val < vec[i]) {
203       val = vec[i];
204       ret = i;
205     }
206   }
207   return ret;
208 }
209
210 /**
211  * @brief     Create & initialize environment
212  * @param[in] input_size State Size : 4 for cartpole-v0
213  * @param[in] output_size Action Size : 2 for cartpole-v0
214  * @retval Env object pointer
215  */
216 static PTR init_environment(int &input_size, int &output_size) {
217 #ifdef USE_GYM
218   boost::shared_ptr<Gym::Client> client;
219   std::string env_id = "CartPole-v0";
220   try {
221     client = Gym::client_create("127.0.0.1", 5000);
222   } catch (const std::exception &e) {
223     fprintf(stderr, "ERROR: %s\n", e.what());
224     return NULL;
225   }
226
227   boost::shared_ptr<ENV> env = client->make(env_id);
228   boost::shared_ptr<Gym::Space> action_space = env->action_space();
229   boost::shared_ptr<Gym::Space> observation_space = env->observation_space();
230
231   input_size = observation_space->sample().size();
232
233   output_size = action_space->discreet_n;
234 #else
235   std::shared_ptr<ENV> env(new ENV);
236   env->init();
237   input_size = env->getInputSize();
238   output_size = env->getOutputSize();
239 #endif
240
241   return env;
242 }
243
244 /**
245  * @brief     Calculate DeepQ
246  * @param[in]  arg 1 : configuration file path
247  */
248 int main(int argc, char **argv) {
249   if (argc < 2) {
250     std::cout << "./DeepQ Config.ini\n";
251     exit(0);
252   }
253   const std::string weight_file = "model_deepq.bin";
254   const std::vector<std::string> args(argv + 1, argv + argc);
255   std::string config = args[0];
256
257   std::string filepath = "debug.txt";
258   std::ofstream writeFile(filepath.data());
259
260   if (!writeFile.is_open()) {
261     std::cout << "Error opening file" << std::endl;
262     return 0;
263   };
264   seed = time(NULL);
265   srand(seed);
266   std::deque<Experience> expQ;
267
268   PTR env;
269
270   /**
271    * @brief     Initialize Environment
272    */
273   int input_size, output_size;
274   env = init_environment(input_size, output_size);
275   printf("input_size %d, output_size %d\n", input_size, output_size);
276
277   /**
278    * @brief     Create mainNet & Target Net
279    */
280   nntrainer::NeuralNetwork mainNet;
281   nntrainer::NeuralNetwork targetNet;
282
283   try {
284     mainNet.loadFromConfig(config);
285     mainNet.compile();
286     mainNet.initialize();
287     targetNet.loadFromConfig(config);
288     targetNet.compile();
289     targetNet.initialize();
290   } catch (...) {
291     std::cerr << "Error during init" << std::endl;
292     return 0;
293   }
294
295   /**
296    * @brief     Read Model Data if any
297    */
298   try {
299     mainNet.load(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
300     targetNet.load(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
301   } catch (...) {
302     std::cerr << "Error during readBin\n";
303     return 1;
304   }
305
306   /**
307    * @brief     Run Episode
308    */
309   for (int episode = 0; episode < MAX_EPISODES; episode++) {
310     float epsilon = 1. / ((episode / 10) + 1);
311     bool done = false;
312     int step_count = 0;
313     STATE s;
314     STATE next_s;
315     s.done = false;
316     next_s.done = false;
317     env->reset(&s);
318
319     /**
320      * @brief     Do until the end of episode
321      */
322     while (!done) {
323       std::vector<float> action;
324       float r = RandomFloat(0.0, 1.0);
325
326       if (r < epsilon && TRAINING) {
327 #ifdef USE_GYM
328         boost::shared_ptr<Gym::Space> action_space = env->action_space();
329         action = action_space->sample();
330 #else
331         action = env->sample();
332 #endif
333         std::cout << "test result random action : " << action[0] << "\n";
334       } else {
335         std::vector<float> input(s.observation.begin(), s.observation.end());
336         /**
337          * @brief     get action with input State with mainNet
338          */
339         nntrainer::Tensor in_tensor;
340         nntrainer::sharedConstTensor test;
341         try {
342           in_tensor = nntrainer::Tensor(
343             {input}, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16});
344         } catch (...) {
345           std::cerr << "Error while construct tensor" << std::endl;
346           return 0;
347         }
348         try {
349           test = mainNet.forwarding({MAKE_SHARED_TENSOR(in_tensor)})[0];
350         } catch (...) {
351           std::cerr << "Error while forwarding the network" << std::endl;
352           return 0;
353         }
354         const float *data = test->getData();
355         unsigned int len = test->getDim().getDataLen();
356         std::vector<float> temp(data, data + len);
357         action.push_back(argmax(temp));
358
359         std::cout << "qvalues : [";
360         std::cout.width(10);
361         std::cout << temp[0] << "][";
362         std::cout.width(10);
363         std::cout << temp[1] << "] : ACTION (argmax) = ";
364         std::cout.width(3);
365         std::cout << argmax(temp) << "\n";
366       }
367
368       /**
369        * @brief     step Env with this action & save next State in next_s
370        */
371       env->step(action, RENDER, &next_s);
372       Experience ex;
373       ex.state = s;
374       ex.action = action;
375       ex.reward = next_s.reward;
376       ex.next_state = next_s;
377       ex.done = next_s.done;
378
379       if (expQ.size() > REPLAY_MEMORY) {
380         expQ.pop_front();
381       }
382
383       done = next_s.done;
384
385       /**
386        * @brief     Set Penalty or reward
387        */
388       if (done) {
389         std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DONE : Episode "
390                   << episode << " Iteration : " << step_count << "\n";
391         ex.reward = -100.0;
392         if (!TRAINING)
393           break;
394       }
395
396       /**
397        * @brief     Save at the Experience Replay Buffer
398        */
399       expQ.push_back(ex);
400
401       s = next_s;
402       step_count++;
403
404       if (step_count > 10000) {
405         std::cout << "step_count is over 10000\n";
406         break;
407       }
408     }
409     if (step_count > 10000)
410       break;
411
412     if (!TRAINING && done)
413       break;
414
415     /**
416      * @brief     Training after finishing 10 episodes
417      */
418     if (episode % 10 == 1 && TRAINING) {
419       for (int iter = 0; iter < 50; iter++) {
420         /**
421          * @brief     Get batch size of Experience
422          */
423         std::vector<Experience> in_Exp = getBatchSizeData(expQ);
424         std::vector<std::vector<std::vector<std::vector<float>>>> inbatch;
425         std::vector<std::vector<std::vector<std::vector<float>>>> next_inbatch;
426
427         /**
428          * @brief     Generate Lable with next state
429          */
430         for (unsigned int i = 0; i < in_Exp.size(); i++) {
431           STATE state = in_Exp[i].state;
432           STATE next_state = in_Exp[i].next_state;
433           std::vector<float> in(state.observation.begin(),
434                                 state.observation.end());
435           inbatch.push_back({{in}});
436
437           std::vector<float> next_in(next_state.observation.begin(),
438                                      next_state.observation.end());
439           next_inbatch.push_back({{next_in}});
440         }
441
442         nntrainer::Tensor q_in, nq_in;
443         try {
444           q_in = nntrainer::Tensor(
445             inbatch, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16});
446           nq_in = nntrainer::Tensor(next_inbatch, {nntrainer::Tformat::NCHW,
447                                                    nntrainer::Tdatatype::FP16});
448         } catch (...) {
449           std::cerr << "Error during tensor constructino" << std::endl;
450           return 0;
451         }
452
453         /**
454          * @brief     run forward propagation with mainNet
455          */
456         nntrainer::sharedConstTensor Q;
457         try {
458           Q = mainNet.forwarding({MAKE_SHARED_TENSOR(q_in)})[0];
459         } catch (...) {
460           std::cerr << "Error during forwarding main network" << std::endl;
461           return -1;
462         }
463
464         /**
465          * @brief     run forward propagation with targetNet
466          */
467         nntrainer::sharedConstTensor NQ;
468         try {
469           NQ = targetNet.forwarding({MAKE_SHARED_TENSOR(nq_in)})[0];
470         } catch (...) {
471           std::cerr << "Error during forwarding target network" << std::endl;
472           return -1;
473         }
474         const float *nqa = NQ->getData();
475
476         /**
477          * @brief     Update Q values & udpate mainNetwork
478          */
479         nntrainer::Tensor tempQ = *Q;
480         for (unsigned int i = 0; i < in_Exp.size(); i++) {
481           if (in_Exp[i].done) {
482             try {
483               tempQ.setValue(i, 0, 0, (int)in_Exp[i].action[0],
484                              (float)in_Exp[i].reward);
485             } catch (...) {
486               std::cerr << "Error during set a value" << std::endl;
487               return -1;
488             }
489           } else {
490             float next = (nqa[i * NQ->width()] > nqa[i * NQ->width() + 1])
491                            ? nqa[i * NQ->width()]
492                            : nqa[i * NQ->width() + 1];
493             try {
494               tempQ.setValue(i, 0, 0, (int)in_Exp[i].action[0],
495                              (float)in_Exp[i].reward + DISCOUNT * next);
496             } catch (...) {
497               std::cerr << "Error during set value" << std::endl;
498               return -1;
499             }
500           }
501         }
502         nntrainer::Tensor in_tensor;
503         try {
504           in_tensor = nntrainer::Tensor(
505             inbatch, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16});
506           mainNet.forwarding({MAKE_SHARED_TENSOR(in_tensor)}, {Q});
507           mainNet.backwarding(iter);
508         } catch (...) {
509           std::cerr << "Error during backwarding the network" << std::endl;
510           return -1;
511         }
512       }
513
514       try {
515         writeFile << "mainNet Loss : " << mainNet.getLoss()
516                   << " : targetNet Loss : " << targetNet.getLoss() << "\n";
517         std::cout << "\n\n =================== TRAINIG & COPY NET "
518                      "==================\n\n";
519         std::cout << "mainNet Loss : ";
520         std::cout.width(15);
521         std::cout << mainNet.getLoss() << "\n targetNet Loss : ";
522         std::cout.width(15);
523         std::cout << targetNet.getLoss() << "\n\n";
524       } catch (std::exception &e) {
525         std::cerr << "Error during getLoss: " << e.what() << "\n";
526         return 1;
527       }
528
529       try {
530         targetNet.load(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
531         mainNet.save(weight_file, ml::train::ModelFormat::MODEL_FORMAT_BIN);
532       } catch (std::exception &e) {
533         std::cerr << "Error during saveBin: " << e.what() << "\n";
534         return 1;
535       }
536     }
537   }
538
539   writeFile.close();
540   return 0;
541 }