Fix file permissions.
[platform/core/security/key-manager.git] / src / manager / dpl / db / src / sql_connection.cpp
1 /*
2  * Copyright (c) 2014 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  *    Licensed under the Apache License, Version 2.0 (the "License");
5  *    you may not use this file except in compliance with the License.
6  *    You may obtain a copy of the License at
7  *
8  *        http://www.apache.org/licenses/LICENSE-2.0
9  *
10  *    Unless required by applicable law or agreed to in writing, software
11  *    distributed under the License is distributed on an "AS IS" BASIS,
12  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  *    See the License for the specific language governing permissions and
14  *    limitations under the License.
15  */
16 /*
17  * @file        sql_connection.cpp
18  * @author      Przemyslaw Dobrowolski (p.dobrowolsk@samsung.com)
19  * @version     1.0
20  * @brief       This file is the implementation file of SQL connection
21  */
22 #pragma GCC diagnostic push
23 #pragma GCC diagnostic warning "-Wdeprecated-declarations"
24
#include <stddef.h>
#include <dpl/db/sql_connection.h>
#include <dpl/db/naive_synchronization_object.h>
#include <dpl/assert.h>
#include <dpl/scoped_ptr.h>
#include <unistd.h>
#include <algorithm>
#include <cstdio>
#include <cstdarg>
#include <cstring>
#include <iterator>
#include <memory>
#include <noncopyable.h>
35
36
namespace {
// Maximum number of attempts made when the library reports SQLCIPHER_BUSY
// before an operation is reported as "permanently busy".
constexpr int MAX_RETRY = 10;

/*
 * RAII holder for a va_list: va_end() is called automatically when the
 * holder goes out of scope, so every exit path of the enclosing variadic
 * function (including exceptions) releases the list.
 *
 * The holder is deliberately non-copyable: a copy would call va_end() a
 * second time on the same list, which is undefined behaviour.
 */
struct ScopedVaList {
    ScopedVaList() = default;
    ScopedVaList(const ScopedVaList &) = delete;
    ScopedVaList &operator=(const ScopedVaList &) = delete;
    ~ScopedVaList() { va_end(args); }

    va_list args;
};

// Declares a ScopedVaList called `name` and starts it after `param`, which
// must be the last named parameter of the enclosing variadic function.
// Use it as a statement:  scoped_va_start(svl, format);
// (Being a macro, it is not confined to this anonymous namespace.)
#define scoped_va_start(name, param) ScopedVaList name; va_start(name.args, param)
} // namespace anonymous
47
48 namespace CKM {
49 namespace DB {
50 namespace { // anonymous
51 class ScopedNotifyAll
52 {
53   private:
54     SqlConnection::SynchronizationObject *m_synchronizationObject;
55
56   public:
57     NONCOPYABLE(ScopedNotifyAll)
58
59     explicit ScopedNotifyAll(
60         SqlConnection::SynchronizationObject *synchronizationObject) :
61         m_synchronizationObject(synchronizationObject)
62     {
63     }
64
65     ~ScopedNotifyAll()
66     {
67         if (!m_synchronizationObject)
68             return;
69
70         LogPedantic("Notifying after successful synchronize");
71         m_synchronizationObject->NotifyAll();
72     }
73 };
74 } // namespace anonymous
75
76 SqlConnection::DataCommand::DataCommand(SqlConnection *connection,
77                                         const char *buffer) :
78     m_masterConnection(connection),
79     m_stmt(NULL)
80 {
81     Assert(connection != NULL);
82
83     // Notify all after potentially synchronized database connection access
84     ScopedNotifyAll notifyAll(connection->m_synchronizationObject.get());
85
86     for (int i = 0; i < MAX_RETRY; i++) {
87         int ret = sqlcipher3_prepare_v2(connection->m_connection,
88                                      buffer, strlen(buffer),
89                                      &m_stmt, NULL);
90
91         if (ret == SQLCIPHER_OK) {
92             LogPedantic("Prepared data command: " << buffer);
93
94             // Increment stored data command count
95             ++m_masterConnection->m_dataCommandsCount;
96             return;
97         } else if (ret == SQLCIPHER_BUSY) {
98             LogPedantic("Collision occurred while preparing SQL command");
99
100             // Synchronize if synchronization object is available
101             if (connection->m_synchronizationObject) {
102                 LogPedantic("Performing synchronization");
103                 connection->m_synchronizationObject->Synchronize();
104                 continue;
105             }
106
107             // No synchronization object defined. Fail.
108         }
109
110         // Fatal error
111         const char *error = sqlcipher3_errmsg(m_masterConnection->m_connection);
112
113         LogError("SQL prepare data command failed");
114         LogError("    Statement: " << buffer);
115         LogError("    Error: " << error);
116
117         ThrowMsg(Exception::SyntaxError, error);
118     }
119
120     LogError("sqlite in the state of possible infinite loop");
121     ThrowMsg(Exception::InternalError, "sqlite permanently busy");
122 }
123
124 SqlConnection::DataCommand::~DataCommand()
125 {
126     LogPedantic("SQL data command finalizing");
127
128     if (sqlcipher3_finalize(m_stmt) != SQLCIPHER_OK)
129         LogError("Failed to finalize data command");
130
131     // Decrement stored data command count
132     --m_masterConnection->m_dataCommandsCount;
133 }
134
135 void SqlConnection::DataCommand::CheckBindResult(int result)
136 {
137     if (result != SQLCIPHER_OK) {
138         const char *error = sqlcipher3_errmsg(
139                 m_masterConnection->m_connection);
140
141         LogError("Failed to bind SQL statement parameter");
142         LogError("    Error: " << error);
143
144         ThrowMsg(Exception::SyntaxError, error);
145     }
146 }
147
148 void SqlConnection::DataCommand::BindNull(
149     SqlConnection::ArgumentIndex position)
150 {
151     CheckBindResult(sqlcipher3_bind_null(m_stmt, position));
152     LogPedantic("SQL data command bind null: ["
153                 << position << "]");
154 }
155
156 void SqlConnection::DataCommand::BindInteger(
157     SqlConnection::ArgumentIndex position,
158     int value)
159 {
160     CheckBindResult(sqlcipher3_bind_int(m_stmt, position, value));
161     LogPedantic("SQL data command bind integer: ["
162                 << position << "] -> " << value);
163 }
164
165 void SqlConnection::DataCommand::BindInt8(
166     SqlConnection::ArgumentIndex position,
167     int8_t value)
168 {
169     CheckBindResult(sqlcipher3_bind_int(m_stmt, position,
170                                      static_cast<int>(value)));
171     LogPedantic("SQL data command bind int8: ["
172                 << position << "] -> " << value);
173 }
174
175 void SqlConnection::DataCommand::BindInt16(
176     SqlConnection::ArgumentIndex position,
177     int16_t value)
178 {
179     CheckBindResult(sqlcipher3_bind_int(m_stmt, position,
180                                      static_cast<int>(value)));
181     LogPedantic("SQL data command bind int16: ["
182                 << position << "] -> " << value);
183 }
184
185 void SqlConnection::DataCommand::BindInt32(
186     SqlConnection::ArgumentIndex position,
187     int32_t value)
188 {
189     CheckBindResult(sqlcipher3_bind_int(m_stmt, position,
190                                      static_cast<int>(value)));
191     LogPedantic("SQL data command bind int32: ["
192                 << position << "] -> " << value);
193 }
194
195 void SqlConnection::DataCommand::BindInt64(
196     SqlConnection::ArgumentIndex position,
197     int64_t value)
198 {
199     CheckBindResult(sqlcipher3_bind_int64(m_stmt, position,
200                                        static_cast<sqlcipher3_int64>(value)));
201     LogPedantic("SQL data command bind int64: ["
202                 << position << "] -> " << value);
203 }
204
205 void SqlConnection::DataCommand::BindFloat(
206     SqlConnection::ArgumentIndex position,
207     float value)
208 {
209     CheckBindResult(sqlcipher3_bind_double(m_stmt, position,
210                                         static_cast<double>(value)));
211     LogPedantic("SQL data command bind float: ["
212                 << position << "] -> " << value);
213 }
214
215 void SqlConnection::DataCommand::BindDouble(
216     SqlConnection::ArgumentIndex position,
217     double value)
218 {
219     CheckBindResult(sqlcipher3_bind_double(m_stmt, position, value));
220     LogPedantic("SQL data command bind double: ["
221                 << position << "] -> " << value);
222 }
223
224 void SqlConnection::DataCommand::BindString(
225     SqlConnection::ArgumentIndex position,
226     const char *value)
227 {
228     if (!value) {
229         BindNull(position);
230         return;
231     }
232
233     // Assume that text may disappear
234     CheckBindResult(sqlcipher3_bind_text(m_stmt, position,
235                                       value, strlen(value),
236                                       SQLCIPHER_TRANSIENT));
237
238     LogPedantic("SQL data command bind string: ["
239                 << position << "] -> " << value);
240 }
241
242 void SqlConnection::DataCommand::BindBlob(
243     SqlConnection::ArgumentIndex position,
244     const RawBuffer &raw)
245 {
246     if (raw.size() == 0) {
247         BindNull(position);
248         return;
249     }
250
251     // Assume that blob may dissappear
252     CheckBindResult(sqlcipher3_bind_blob(m_stmt, position,
253                                       raw.data(), raw.size(),
254                                       SQLCIPHER_TRANSIENT));
255     LogPedantic("SQL data command bind blob of size: ["
256                 << position << "] -> " << raw.size());
257 }
258
259 void SqlConnection::DataCommand::BindInteger(
260     SqlConnection::ArgumentIndex position,
261     const boost::optional<int> &value)
262 {
263     if (!value)
264         BindNull(position);
265     else
266         BindInteger(position, *value);
267 }
268
269 void SqlConnection::DataCommand::BindInt8(
270     SqlConnection::ArgumentIndex position,
271     const boost::optional<int8_t> &value)
272 {
273     if (!value)
274         BindNull(position);
275     else
276         BindInt8(position, *value);
277 }
278
279 void SqlConnection::DataCommand::BindInt16(
280     SqlConnection::ArgumentIndex position,
281     const boost::optional<int16_t> &value)
282 {
283     if (!value)
284         BindNull(position);
285     else
286         BindInt16(position, *value);
287 }
288
289 void SqlConnection::DataCommand::BindInt32(
290     SqlConnection::ArgumentIndex position,
291     const boost::optional<int32_t> &value)
292 {
293     if (!value)
294         BindNull(position);
295     else
296         BindInt32(position, *value);
297 }
298
299 void SqlConnection::DataCommand::BindInt64(
300     SqlConnection::ArgumentIndex position,
301     const boost::optional<int64_t> &value)
302 {
303     if (!value)
304         BindNull(position);
305     else
306         BindInt64(position, *value);
307 }
308
309 void SqlConnection::DataCommand::BindFloat(
310     SqlConnection::ArgumentIndex position,
311     const boost::optional<float> &value)
312 {
313     if (!value)
314         BindNull(position);
315     else
316         BindFloat(position, *value);
317 }
318
319 void SqlConnection::DataCommand::BindDouble(
320     SqlConnection::ArgumentIndex position,
321     const boost::optional<double> &value)
322 {
323     if (!value)
324         BindNull(position);
325     else
326         BindDouble(position, *value);
327 }
328
329 void SqlConnection::DataCommand::BindBlob(
330     SqlConnection::ArgumentIndex position,
331     const boost::optional<RawBuffer> &value)
332 {
333     if (!!value)
334         BindBlob(position, *value);
335     else
336         BindNull(position);
337 }
338
339 bool SqlConnection::DataCommand::Step()
340 {
341     // Notify all after potentially synchronized database connection access
342     ScopedNotifyAll notifyAll(
343         m_masterConnection->m_synchronizationObject.get());
344
345     for (int i = 0; i < MAX_RETRY; i++) {
346         int ret = sqlcipher3_step(m_stmt);
347
348         if (ret == SQLCIPHER_ROW) {
349             LogPedantic("SQL data command step ROW");
350             return true;
351         } else if (ret == SQLCIPHER_DONE) {
352             LogPedantic("SQL data command step DONE");
353             return false;
354         } else if (ret == SQLCIPHER_BUSY) {
355             LogPedantic("Collision occurred while executing SQL command");
356
357             // Synchronize if synchronization object is available
358             if (m_masterConnection->m_synchronizationObject) {
359                 LogPedantic("Performing synchronization");
360
361                 m_masterConnection->
362                     m_synchronizationObject->Synchronize();
363
364                 continue;
365             }
366             // No synchronization object defined. Fail.
367         }
368
369         // Fatal error
370         const char *error = sqlcipher3_errmsg(m_masterConnection->m_connection);
371
372         LogError("SQL step data command failed");
373         LogError("    Error: " << error);
374
375         ThrowMsg(Exception::InternalError, error);
376     }
377
378     LogError("sqlite in the state of possible infinite loop");
379     ThrowMsg(Exception::InternalError, "sqlite permanently busy");
380 }
381
382 void SqlConnection::DataCommand::Reset()
383 {
384     /*
385      * According to:
386      * http://www.sqllite.org/c3ref/stmt.html
387      *
388      * if last sqlcipher3_step command on this stmt returned an error,
389      * then sqlcipher3_reset will return that error, althought it is not an error.
390      * So sqlcipher3_reset allways succedes.
391      */
392     sqlcipher3_reset(m_stmt);
393
394     LogPedantic("SQL data command reset");
395 }
396
397 void SqlConnection::DataCommand::CheckColumnIndex(
398     SqlConnection::ColumnIndex column)
399 {
400     if (column < 0 || column >= sqlcipher3_column_count(m_stmt))
401         ThrowMsg(Exception::InvalidColumn, "Column index is out of bounds");
402 }
403
404 bool SqlConnection::DataCommand::IsColumnNull(
405     SqlConnection::ColumnIndex column)
406 {
407     LogPedantic("SQL data command get column type: [" << column << "]");
408     CheckColumnIndex(column);
409     return sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL;
410 }
411
412 int SqlConnection::DataCommand::GetColumnInteger(
413     SqlConnection::ColumnIndex column)
414 {
415     LogPedantic("SQL data command get column integer: [" << column << "]");
416     CheckColumnIndex(column);
417     int value = sqlcipher3_column_int(m_stmt, column);
418     LogPedantic("    Value: " << value);
419     return value;
420 }
421
422 int8_t SqlConnection::DataCommand::GetColumnInt8(
423     SqlConnection::ColumnIndex column)
424 {
425     LogPedantic("SQL data command get column int8: [" << column << "]");
426     CheckColumnIndex(column);
427     int8_t value = static_cast<int8_t>(sqlcipher3_column_int(m_stmt, column));
428     LogPedantic("    Value: " << value);
429     return value;
430 }
431
432 int16_t SqlConnection::DataCommand::GetColumnInt16(
433     SqlConnection::ColumnIndex column)
434 {
435     LogPedantic("SQL data command get column int16: [" << column << "]");
436     CheckColumnIndex(column);
437     int16_t value = static_cast<int16_t>(sqlcipher3_column_int(m_stmt, column));
438     LogPedantic("    Value: " << value);
439     return value;
440 }
441
442 int32_t SqlConnection::DataCommand::GetColumnInt32(
443     SqlConnection::ColumnIndex column)
444 {
445     LogPedantic("SQL data command get column int32: [" << column << "]");
446     CheckColumnIndex(column);
447     int32_t value = static_cast<int32_t>(sqlcipher3_column_int(m_stmt, column));
448     LogPedantic("    Value: " << value);
449     return value;
450 }
451
452 int64_t SqlConnection::DataCommand::GetColumnInt64(
453     SqlConnection::ColumnIndex column)
454 {
455     LogPedantic("SQL data command get column int64: [" << column << "]");
456     CheckColumnIndex(column);
457     int64_t value = static_cast<int64_t>(sqlcipher3_column_int64(m_stmt, column));
458     LogPedantic("    Value: " << value);
459     return value;
460 }
461
462 float SqlConnection::DataCommand::GetColumnFloat(
463     SqlConnection::ColumnIndex column)
464 {
465     LogPedantic("SQL data command get column float: [" << column << "]");
466     CheckColumnIndex(column);
467     float value = static_cast<float>(sqlcipher3_column_double(m_stmt, column));
468     LogPedantic("    Value: " << value);
469     return value;
470 }
471
472 double SqlConnection::DataCommand::GetColumnDouble(
473     SqlConnection::ColumnIndex column)
474 {
475     LogPedantic("SQL data command get column double: [" << column << "]");
476     CheckColumnIndex(column);
477     double value = sqlcipher3_column_double(m_stmt, column);
478     LogPedantic("    Value: " << value);
479     return value;
480 }
481
482 std::string SqlConnection::DataCommand::GetColumnString(
483     SqlConnection::ColumnIndex column)
484 {
485     LogPedantic("SQL data command get column string: [" << column << "]");
486     CheckColumnIndex(column);
487
488     const char *value = reinterpret_cast<const char *>(
489             sqlcipher3_column_text(m_stmt, column));
490
491     LogPedantic("Value: " << (value ? value : "NULL"));
492
493     if (value == NULL)
494         return std::string();
495
496     return std::string(value);
497 }
498
499 RawBuffer SqlConnection::DataCommand::GetColumnBlob(
500     SqlConnection::ColumnIndex column)
501 {
502     LogPedantic("SQL data command get column blog: [" << column << "]");
503     CheckColumnIndex(column);
504
505     const unsigned char *value = reinterpret_cast<const unsigned char*>(
506             sqlcipher3_column_blob(m_stmt, column));
507
508     if (value == NULL)
509         return RawBuffer();
510
511     int length = sqlcipher3_column_bytes(m_stmt, column);
512     LogPedantic("Got blob of length: " << length);
513
514     return RawBuffer(value, value + length);
515 }
516
517 boost::optional<int> SqlConnection::DataCommand::GetColumnOptionalInteger(
518     SqlConnection::ColumnIndex column)
519 {
520     LogPedantic("SQL data command get column optional integer: ["
521                 << column << "]");
522     CheckColumnIndex(column);
523     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
524         return boost::optional<int>();
525
526     int value = sqlcipher3_column_int(m_stmt, column);
527     LogPedantic("    Value: " << value);
528     return boost::optional<int>(value);
529 }
530
531 boost::optional<int8_t> SqlConnection::DataCommand::GetColumnOptionalInt8(
532     SqlConnection::ColumnIndex column)
533 {
534     LogPedantic("SQL data command get column optional int8: ["
535                 << column << "]");
536     CheckColumnIndex(column);
537     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
538         return boost::optional<int8_t>();
539
540     int8_t value = static_cast<int8_t>(sqlcipher3_column_int(m_stmt, column));
541     LogPedantic("    Value: " << value);
542     return boost::optional<int8_t>(value);
543 }
544
545 boost::optional<int16_t> SqlConnection::DataCommand::GetColumnOptionalInt16(
546     SqlConnection::ColumnIndex column)
547 {
548     LogPedantic("SQL data command get column optional int16: ["
549                 << column << "]");
550     CheckColumnIndex(column);
551     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
552         return boost::optional<int16_t>();
553
554     int16_t value = static_cast<int16_t>(sqlcipher3_column_int(m_stmt, column));
555     LogPedantic("    Value: " << value);
556     return boost::optional<int16_t>(value);
557 }
558
559 boost::optional<int32_t> SqlConnection::DataCommand::GetColumnOptionalInt32(
560     SqlConnection::ColumnIndex column)
561 {
562     LogPedantic("SQL data command get column optional int32: ["
563                 << column << "]");
564     CheckColumnIndex(column);
565     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
566         return boost::optional<int32_t>();
567
568     int32_t value = static_cast<int32_t>(sqlcipher3_column_int(m_stmt, column));
569     LogPedantic("    Value: " << value);
570     return boost::optional<int32_t>(value);
571 }
572
573 boost::optional<int64_t> SqlConnection::DataCommand::GetColumnOptionalInt64(
574     SqlConnection::ColumnIndex column)
575 {
576     LogPedantic("SQL data command get column optional int64: ["
577                 << column << "]");
578     CheckColumnIndex(column);
579     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
580         return boost::optional<int64_t>();
581
582     int64_t value = static_cast<int64_t>(sqlcipher3_column_int64(m_stmt, column));
583     LogPedantic("    Value: " << value);
584     return boost::optional<int64_t>(value);
585 }
586
587 boost::optional<float> SqlConnection::DataCommand::GetColumnOptionalFloat(
588     SqlConnection::ColumnIndex column)
589 {
590     LogPedantic("SQL data command get column optional float: ["
591                 << column << "]");
592     CheckColumnIndex(column);
593     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
594         return boost::optional<float>();
595
596     float value = static_cast<float>(sqlcipher3_column_double(m_stmt, column));
597     LogPedantic("    Value: " << value);
598     return boost::optional<float>(value);
599 }
600
601 boost::optional<double> SqlConnection::DataCommand::GetColumnOptionalDouble(
602     SqlConnection::ColumnIndex column)
603 {
604     LogPedantic("SQL data command get column optional double: ["
605                 << column << "]");
606     CheckColumnIndex(column);
607     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
608         return boost::optional<double>();
609
610     double value = sqlcipher3_column_double(m_stmt, column);
611     LogPedantic("    Value: " << value);
612     return boost::optional<double>(value);
613 }
614
615 boost::optional<RawBuffer> SqlConnection::DataCommand::GetColumnOptionalBlob(
616     SqlConnection::ColumnIndex column)
617 {
618     LogPedantic("SQL data command get column blog: [" << column << "]");
619     CheckColumnIndex(column);
620
621     if (sqlcipher3_column_type(m_stmt, column) == SQLCIPHER_NULL)
622         return boost::optional<RawBuffer>();
623
624     const unsigned char *value = reinterpret_cast<const unsigned char*>(
625             sqlcipher3_column_blob(m_stmt, column));
626
627     int length = sqlcipher3_column_bytes(m_stmt, column);
628     LogPedantic("Got blob of length: " << length);
629
630     RawBuffer temp(value, value + length);
631     return boost::optional<RawBuffer>(temp);
632 }
633
634 void SqlConnection::Connect(const std::string &address,
635                             Flag::Option flag)
636 {
637     if (m_connection != NULL) {
638         LogPedantic("Already connected.");
639         return;
640     }
641     LogPedantic("Connecting to DB: " << address << "...");
642
643     // Connect to database
644     int result;
645     result = sqlcipher3_open_v2(
646             address.c_str(),
647             &m_connection,
648             flag,
649             NULL);
650
651     if (result == SQLCIPHER_OK) {
652         LogPedantic("Connected to DB");
653     } else {
654         LogError("Failed to connect to DB!");
655         ThrowMsg(Exception::ConnectionBroken, address);
656     }
657
658     // Enable foreign keys
659     TurnOnForeignKeys();
660 }
661
662 const std::string SQLCIPHER_RAW_PREFIX = "x'";
663 const std::string SQLCIPHER_RAW_SUFIX = "'";
664 const std::size_t SQLCIPHER_RAW_DATA_SIZE = 32;
665
666 RawBuffer rawToHexString(const RawBuffer &raw)
667 {
668     RawBuffer output;
669     for (auto &e: raw) {
670         char result[3];
671         snprintf(result, sizeof(result), "%02X", static_cast<unsigned int>(e));
672         output.push_back(static_cast<unsigned char>(result[0]));
673         output.push_back(static_cast<unsigned char>(result[1]));
674     }
675     return output;
676 }
677
678 RawBuffer createHexPass(const RawBuffer &rawPass)
679 {
680     // We are required to pass 64byte long hex password made out of 32byte raw
681     // binary data
682     RawBuffer output;
683     std::copy(SQLCIPHER_RAW_PREFIX.begin(), SQLCIPHER_RAW_PREFIX.end(),
684         std::back_inserter(output));
685
686     RawBuffer password = rawToHexString(rawPass);
687
688     std::copy(password.begin(), password.end(),
689         std::back_inserter(output));
690
691     std::copy(SQLCIPHER_RAW_SUFIX.begin(), SQLCIPHER_RAW_SUFIX.end(),
692         std::back_inserter(output));
693
694     return output;
695 }
696
697 void SqlConnection::SetKey(const RawBuffer &rawPass){
698     if (m_connection == NULL) {
699         LogPedantic("Cannot set key. No connection to DB!");
700         return;
701     }
702     if (rawPass.size() != SQLCIPHER_RAW_DATA_SIZE)
703             ThrowMsg(Exception::InvalidArguments,
704                     "Binary data for raw password should be 32 bytes long.");
705     RawBuffer pass = createHexPass(rawPass);
706     int result = sqlcipher3_key(m_connection, pass.data(), pass.size());
707     if (result == SQLCIPHER_OK) {
708         LogPedantic("Set key on DB");
709     } else {
710         //sqlcipher3_key fails only when m_connection == NULL || key == NULL ||
711         //                            key length == 0
712         LogError("Failed to set key on DB");
713         ThrowMsg(Exception::InvalidArguments, result);
714     }
715
716     m_isKeySet = true;
717 };
718
719 void SqlConnection::ResetKey(const RawBuffer &rawPassOld,
720                              const RawBuffer &rawPassNew) {
721     if (m_connection == NULL) {
722         LogPedantic("Cannot reset key. No connection to DB!");
723         return;
724     }
725     AssertMsg(rawPassOld.size() == SQLCIPHER_RAW_DATA_SIZE &&
726               rawPassNew.size() == SQLCIPHER_RAW_DATA_SIZE,
727             "Binary data for raw password should be 32 bytes long.");
728     // sqlcipher3_rekey requires for key to be already set
729     if (!m_isKeySet)
730         SetKey(rawPassOld);
731
732     RawBuffer pass = createHexPass(rawPassNew);
733     int result = sqlcipher3_rekey(m_connection, pass.data(), pass.size());
734     if (result == SQLCIPHER_OK) {
735         LogPedantic("Reset key on DB");
736     } else {
737         //sqlcipher3_rekey fails only when m_connection == NULL || key == NULL ||
738         //                              key length == 0
739         LogError("Failed to reset key on DB");
740         ThrowMsg(Exception::InvalidArguments, result);
741     }
742 }
743
744 void SqlConnection::Disconnect()
745 {
746     if (m_connection == NULL) {
747         LogPedantic("Already disconnected.");
748         return;
749     }
750
751     LogPedantic("Disconnecting from DB...");
752
753     // All stored data commands must be deleted before disconnect
754     AssertMsg(m_dataCommandsCount == 0,
755            "All stored procedures must be deleted"
756            " before disconnecting SqlConnection");
757
758     int result;
759
760     result = sqlcipher3_close(m_connection);
761
762     if (result != SQLCIPHER_OK) {
763         const char *error = sqlcipher3_errmsg(m_connection);
764         LogError("SQL close failed");
765         LogError("    Error: " << error);
766         Throw(Exception::InternalError);
767     }
768
769     m_connection = NULL;
770
771     LogPedantic("Disconnected from DB");
772 }
773
774 bool SqlConnection::CheckTableExist(const char *tableName)
775 {
776     if (m_connection == NULL) {
777         LogPedantic("Cannot execute command. Not connected to DB!");
778         return false;
779     }
780
781     DataCommandUniquePtr command =
782         PrepareDataCommand("select tbl_name from sqlcipher_master where name=?;");
783
784     command->BindString(1, tableName);
785
786     if (!command->Step()) {
787         LogPedantic("No matching records in table");
788         return false;
789     }
790
791     return command->GetColumnString(0) == tableName;
792 }
793
794 SqlConnection::SqlConnection(const std::string &address,
795                              Flag::Option option,
796                              SynchronizationObject *synchronizationObject) :
797     m_connection(NULL),
798     m_dataCommandsCount(0),
799     m_synchronizationObject(synchronizationObject),
800     m_isKeySet(false)
801 {
802     LogPedantic("Opening database connection to: " << address);
803
804     // Connect to DB
805     SqlConnection::Connect(address, option);
806
807     if (!m_synchronizationObject)
808         LogPedantic("No synchronization object defined");
809 }
810
811 SqlConnection::~SqlConnection()
812 {
813     LogPedantic("Closing database connection");
814
815     // Disconnect from DB
816     Try
817     {
818         SqlConnection::Disconnect();
819     }
820     Catch(Exception::Base)
821     {
822         LogError("Failed to disconnect from database");
823     }
824 }
825
826 int SqlConnection::Output::Callback(void* param, int columns, char** values, char** names)
827 {
828     if (param)
829         static_cast<Output*>(param)->SetResults(columns, values, names);
830     return 0;
831 }
832
833 void SqlConnection::Output::SetResults(int columns, char** values, char** names)
834 {
835     if (m_names.empty()) {
836         for (int i=0; i < columns; i++)
837             m_names.push_back(names[i] ? names[i] : "NULL");
838     }
839     Row row;
840     for (int i=0; i < columns; i++)
841         row.push_back(values[i] ? values[i] : "NULL");
842     m_values.push_back(std::move(row));
843 }
844
845 void SqlConnection::ExecCommandHelper(Output* out, const char* format, va_list args)
846 {
847     if (m_connection == NULL) {
848         LogError("Cannot execute command. Not connected to DB!");
849         return;
850     }
851
852     if (format == NULL) {
853         LogError("Null query!");
854         ThrowMsg(Exception::SyntaxError, "Null statement");
855     }
856
857     char *query;
858
859     if (vasprintf(&query, format, args) == -1) {
860         LogError("Failed to allocate statement string");
861         return;
862     }
863
864     CharUniquePtr queryPtr(query);
865
866     LogPedantic("Executing SQL command: " << queryPtr.get());
867
868     // Notify all after potentially synchronized database connection access
869     ScopedNotifyAll notifyAll(m_synchronizationObject.get());
870
871     for (int i = 0; i < MAX_RETRY; i++) {
872         char *errorBuffer;
873         int ret = sqlcipher3_exec(m_connection,
874                                   queryPtr.get(),
875                                   out ? &Output::Callback : NULL,
876                                   out,
877                                   &errorBuffer);
878
879         std::string errorMsg;
880
881         // Take allocated error buffer
882         if (errorBuffer != NULL) {
883             errorMsg = errorBuffer;
884             sqlcipher3_free(errorBuffer);
885         }
886
887         if (ret == SQLCIPHER_OK)
888             return;
889
890         if (ret == SQLCIPHER_BUSY) {
891             LogPedantic("Collision occurred while executing SQL command");
892
893             // Synchronize if synchronization object is available
894             if (m_synchronizationObject) {
895                 LogPedantic("Performing synchronization");
896                 m_synchronizationObject->Synchronize();
897                 continue;
898             }
899
900             // No synchronization object defined. Fail.
901         }
902
903         // Fatal error
904         LogError("Failed to execute SQL command. Error: " << errorMsg);
905         ThrowMsg(Exception::SyntaxError, errorMsg);
906     }
907
908     LogError("sqlite in the state of possible infinite loop");
909     ThrowMsg(Exception::InternalError, "sqlite permanently busy");
910 }
911
912 void SqlConnection::ExecCommand(Output* out, const char *format, ...)
913 {
914     scoped_va_start(svl, format);
915
916     ExecCommandHelper(out, format, svl.args);
917 }
918
919 void SqlConnection::ExecCommand(const char *format, ...)
920 {
921     scoped_va_start(svl, format);
922
923     ExecCommandHelper(NULL, format, svl.args);
924 }
925
926 SqlConnection::DataCommandUniquePtr SqlConnection::PrepareDataCommand(
927     const char *format,
928     ...)
929 {
930     if (m_connection == NULL) {
931         LogError("Cannot execute data command. Not connected to DB!");
932         return DataCommandUniquePtr();
933     }
934
935     char *rawBuffer;
936
937     va_list args;
938     va_start(args, format);
939
940     if (vasprintf(&rawBuffer, format, args) == -1)
941         rawBuffer = NULL;
942
943     va_end(args);
944
945     CharUniquePtr buffer(rawBuffer);
946
947     if (!buffer) {
948         LogError("Failed to allocate statement string");
949         return DataCommandUniquePtr();
950     }
951
952     LogPedantic("Executing SQL data command: " << buffer.get());
953
954     return DataCommandUniquePtr(new DataCommand(this, buffer.get()));
955 }
956
957 SqlConnection::RowID SqlConnection::GetLastInsertRowID() const
958 {
959     return static_cast<RowID>(sqlcipher3_last_insert_rowid(m_connection));
960 }
961
962 void SqlConnection::TurnOnForeignKeys()
963 {
964     ExecCommand("PRAGMA foreign_keys = ON;");
965 }
966
967 void SqlConnection::BeginTransaction()
968 {
969     ExecCommand("BEGIN;");
970 }
971
972 void SqlConnection::RollbackTransaction()
973 {
974     ExecCommand("ROLLBACK;");
975 }
976
977 void SqlConnection::CommitTransaction()
978 {
979     ExecCommand("COMMIT;");
980 }
981
982 SqlConnection::SynchronizationObject *
983 SqlConnection::AllocDefaultSynchronizationObject()
984 {
985     return new NaiveSynchronizationObject();
986 }
987 } // namespace DB
988 } // namespace CKM
989
990 #pragma GCC diagnostic pop