Thanks to Ivan Stankovic
[platform/upstream/cryptsetup.git] / lib / utils.c
1 #include <stdio.h>
2 #include <string.h>
3 #include <stdlib.h>
4 #include <stddef.h>
5 #include <stdarg.h>
6 #include <errno.h>
7 #include <linux/fs.h>
8 #include <sys/types.h>
9 #include <unistd.h>
10 #include <sys/types.h>
11 #include <sys/stat.h>
12 #include <sys/ioctl.h>
13 #include <fcntl.h>
14 #include <termios.h>
15
16 #include "libcryptsetup.h"
17 #include "internal.h"
18
19
20 struct safe_allocation {
21         size_t  size;
22         char    data[1];
23 };
24
25 static char *error=NULL;
26
27 void set_error_va(const char *fmt, va_list va)
28 {
29
30         if(error) {
31             free(error);
32             error=NULL;
33         }
34
35         vasprintf(&error, fmt, va);
36 }
37
38 void set_error(const char *fmt, ...)
39 {
40         va_list va;
41
42         va_start(va, fmt);
43         set_error_va(fmt, va);
44         va_end(va);
45 }
46
47 const char *get_error(void)
48 {
49         return error;
50 }
51
52 void *safe_alloc(size_t size)
53 {
54         struct safe_allocation *alloc;
55
56         if (!size)
57                 return NULL;
58
59         alloc = malloc(size + offsetof(struct safe_allocation, data));
60         if (!alloc)
61                 return NULL;
62
63         alloc->size = size;
64
65         return &alloc->data;
66 }
67
68 void safe_free(void *data)
69 {
70         struct safe_allocation *alloc;
71
72         if (!data)
73                 return;
74
75         alloc = data - offsetof(struct safe_allocation, data);
76
77         memset(data, 0, alloc->size);
78
79         alloc->size = 0x55aa55aa;
80         free(alloc);
81 }
82
83 void *safe_realloc(void *data, size_t size)
84 {
85         void *new_data;
86
87         new_data = safe_alloc(size);
88
89         if (new_data && data) {
90                 struct safe_allocation *alloc;
91
92                 alloc = data - offsetof(struct safe_allocation, data);
93
94                 if (size > alloc->size)
95                         size = alloc->size;
96
97                 memcpy(new_data, data, size);
98         }
99
100         safe_free(data);
101         return new_data;
102 }
103
104 char *safe_strdup(const char *s)
105 {
106         char *s2 = safe_alloc(strlen(s) + 1);
107
108         if (!s2)
109                 return NULL;
110
111         return strcpy(s2, s);
112 }
113
114 /* Credits go to Michal's padlock patches for this alignment code */
115
116 static void *aligned_malloc(char **base, int size, int alignment) 
117 {
118         char *ptr;
119
120         ptr  = malloc(size + alignment);
121         if(ptr == NULL) return NULL;
122
123         *base = ptr;
124         if(alignment > 1 && ((long)ptr & (alignment - 1))) {
125                 ptr += alignment - ((long)(ptr) & (alignment - 1));
126         }
127         return ptr;
128 }
129
130 static int sector_size(int fd) 
131 {
132         int bsize;
133         if (ioctl(fd,BLKSSZGET, &bsize) < 0)
134                 return -EINVAL;
135         else
136                 return bsize;
137 }
138
139 int sector_size_for_device(const char *device)
140 {
141         int fd = open(device, O_RDONLY);
142         int r;
143         if(fd < 0)
144                 return -EINVAL;
145         r = sector_size(fd);
146         close(fd);
147         return r;
148 }
149
150 ssize_t write_blockwise(int fd, const void *orig_buf, size_t count) 
151 {
152         char *padbuf; char *padbuf_base;
153         char *buf = (char *)orig_buf;
154         int r = 0;
155         int hangover; int solid; int bsize;
156
157         if ((bsize = sector_size(fd)) < 0)
158                 return bsize;
159
160         hangover = count % bsize;
161         solid = count - hangover;
162
163         padbuf = aligned_malloc(&padbuf_base, bsize, bsize);
164         if(padbuf == NULL) return -ENOMEM;
165
166         while(solid) {
167                 memcpy(padbuf, buf, bsize);
168                 r = write(fd, padbuf, bsize);
169                 if(r < 0 || r != bsize) goto out;
170
171                 solid -= bsize;
172                 buf += bsize;
173         }
174         if(hangover) {
175                 r = read(fd,padbuf,bsize);
176                 if(r < 0 || r != bsize) goto out;
177
178                 lseek(fd,-bsize,SEEK_CUR);
179                 memcpy(padbuf,buf,hangover);
180
181                 r = write(fd,padbuf, bsize);
182                 if(r < 0 || r != bsize) goto out;
183                 buf += hangover;
184         }
185  out:
186         free(padbuf_base);
187         return (buf-(char *)orig_buf)?(buf-(char *)orig_buf):r;
188
189 }
190
191 ssize_t read_blockwise(int fd, void *orig_buf, size_t count) {
192         char *padbuf; char *padbuf_base;
193         char *buf = (char *)orig_buf;
194         int r = 0;
195         int step;
196         int bsize;
197
198         if ((bsize = sector_size(fd)) < 0)
199                 return bsize;
200
201         padbuf = aligned_malloc(&padbuf_base, bsize, bsize);
202         if(padbuf == NULL) return -ENOMEM;
203
204         while(count) {
205                 r = read(fd,padbuf,bsize);
206                 if(r < 0 || r != bsize) {
207                         set_error("read failed in read_blockwise.\n");
208                         goto out;
209                 }
210                 step = count<bsize?count:bsize;
211                 memcpy(buf,padbuf,step);
212                 buf += step;
213                 count -= step;
214         }
215  out:
216         free(padbuf_base); 
217         return (buf-(char *)orig_buf)?(buf-(char *)orig_buf):r;
218 }
219
220 /* 
221  * Combines llseek with blockwise write. write_blockwise can already deal with short writes
222  * but we also need a function to deal with short writes at the start. But this information
223  * is implicitly included in the read/write offset, which can not be set to non-aligned 
224  * boundaries. Hence, we combine llseek with write.
225  */
226    
227 ssize_t write_lseek_blockwise(int fd, const char *buf, size_t count, off_t offset) {
228         int bsize = sector_size(fd);
229         const char *orig_buf = buf;
230         char frontPadBuf[bsize];
231         int frontHang = offset % bsize;
232         int r;
233         int innerCount = count < bsize ? count : bsize;
234
235         if (bsize < 0)
236                 return bsize;
237
238         lseek(fd, offset - frontHang, SEEK_SET);
239         if(offset % bsize) {
240                 r = read(fd,frontPadBuf,bsize);
241                 if(r < 0) return -1;
242
243                 memcpy(frontPadBuf+frontHang, buf, innerCount);
244
245                 lseek(fd, offset - frontHang, SEEK_SET);
246                 r = write(fd,frontPadBuf,bsize);
247                 if(r < 0) return -1;
248
249                 buf += innerCount;
250                 count -= innerCount;
251         }
252         if(count <= 0) return buf - orig_buf;
253
254         return write_blockwise(fd, buf, count) + innerCount;
255 }
256
257 /* Password reading helpers */
258
259 static int untimed_read(int fd, char *pass, size_t maxlen)
260 {
261         ssize_t i;
262
263         i = read(fd, pass, maxlen);
264         if (i > 0) {
265                 pass[i-1] = '\0';
266                 i = 0;
267         } else if (i == 0) { /* EOF */
268                 *pass = 0;
269                 i = -1;
270         }
271         return i;
272 }
273
274 static int timed_read(int fd, char *pass, size_t maxlen, long timeout)
275 {
276         struct timeval t;
277         fd_set fds;
278         int failed = -1;
279
280         FD_ZERO(&fds);
281         FD_SET(fd, &fds);
282         t.tv_sec = timeout;
283         t.tv_usec = 0;
284
285         if (select(fd+1, &fds, NULL, NULL, &t) > 0)
286                 failed = untimed_read(fd, pass, maxlen);
287         else
288                 set_error("Operation timed out");
289         return failed;
290 }
291
292 static int interactive_pass(const char *prompt, char *pass, size_t maxlen,
293                 long timeout)
294 {
295         struct termios orig, tmp;
296         int failed = -1;
297         int infd = STDIN_FILENO, outfd;
298
299         if (maxlen < 1)
300                 goto out_err;
301
302         /* Read and write to /dev/tty if available */
303         if ((infd = outfd = open("/dev/tty", O_RDWR)) == -1) {
304                 infd = STDIN_FILENO;
305                 outfd = STDERR_FILENO;
306         }
307
308         if (tcgetattr(infd, &orig)) {
309                 set_error("Unable to get terminal");
310                 goto out_err;
311         }
312         memcpy(&tmp, &orig, sizeof(tmp));
313         tmp.c_lflag &= ~ECHO;
314
315         write(outfd, prompt, strlen(prompt));
316         tcsetattr(infd, TCSAFLUSH, &tmp);
317         if (timeout)
318                 failed = timed_read(infd, pass, maxlen, timeout);
319         else
320                 failed = untimed_read(infd, pass, maxlen);
321         tcsetattr(infd, TCSAFLUSH, &orig);
322
323 out_err:
324         if (!failed)
325                 write(outfd, "\n", 1);
326         if (infd != STDIN_FILENO)
327                 close(infd);
328         return failed;
329 }
330
331 /*
332  * Password reading behaviour matrix of get_key
333  * 
334  *                    p   v   n   h
335  * -----------------+---+---+---+---
336  * interactive      | Y | Y | Y | Inf
337  * from fd          | N | N | Y | Inf
338  * from binary file | N | N | N | Inf or options->key_size
339  *
340  * Legend: p..prompt, v..can verify, n..newline-stop, h..read horizon
341  *
342  * Note: --key-file=- is interpreted as a read from a binary file (stdin)
343  *
344  * Returns true when more keys are available (that is when password
345  * reading can be retried as for interactive terminals).
346  */
347
348 int get_key(char *prompt, char **key, unsigned int *passLen, int key_size,
349             const char *key_file, int passphrase_fd, int timeout, int how2verify)
350 {
351         int fd;
352         const int verify = how2verify & CRYPT_FLAG_VERIFY;
353         const int verify_if_possible = how2verify & CRYPT_FLAG_VERIFY_IF_POSSIBLE;
354         char *pass = NULL;
355         int newline_stop;
356         int read_horizon;
357
358         if(key_file && !strcmp(key_file, "-")) {
359                 /* Allow binary reading from stdin */
360                 fd = passphrase_fd;
361                 newline_stop = 0;
362                 read_horizon = 0;
363         } else if (key_file) {
364                 fd = open(key_file, O_RDONLY);
365                 if (fd < 0) {
366                         char buf[128];
367                         set_error("Error opening key file: %s",
368                                   strerror_r(errno, buf, 128));
369                         goto out_err;
370                 }
371                 newline_stop = 0;
372
373                 /* This can either be 0 (LUKS) or the actually number
374                  * of key bytes (default or passed by -s) */
375                 read_horizon = key_size;
376         } else {
377                 fd = passphrase_fd;
378                 newline_stop = 1;
379                 read_horizon = 0;   /* Infinite, if read from terminal or fd */
380         }       
381
382         /* Interactive case */
383         if(isatty(fd)) {
384                 int i;
385
386                 pass = safe_alloc(512);
387                 if (!pass || (i = interactive_pass(prompt, pass, 512, timeout))) {
388                         set_error("Error reading passphrase");
389                         goto out_err;
390                 }
391                 if (verify || verify_if_possible) {
392                         char pass_verify[512];
393                         i = interactive_pass("Verify passphrase: ", pass_verify, sizeof(pass_verify), timeout);
394                         if (i || strcmp(pass, pass_verify) != 0) {
395                                 set_error("Passphrases do not match");
396                                 goto out_err;
397                         }
398                         memset(pass_verify, 0, sizeof(pass_verify));
399                 }
400                 *passLen = strlen(pass);
401                 *key = pass;
402         } else {
403                 /* 
404                  * This is either a fd-input or a file, in neither case we can verify the input,
405                  * however we don't stop on new lines if it's a binary file.
406                  */
407                 int buflen, i;
408
409                 if(verify) {
410                         set_error("Can't do passphrase verification on non-tty inputs");
411                         goto out_err;
412                 }
413                 /* The following for control loop does an exhausting
414                  * read on the key material file, if requested with
415                  * key_size == 0, as it's done by LUKS. However, we
416                  * should warn the user, if it's a non-regular file,
417                  * such as /dev/random, because in this case, the loop
418                  * will read forever.
419                  */ 
420                 if(key_file && strcmp(key_file, "-") && read_horizon == 0) {
421                         struct stat st;
422                         if(stat(key_file, &st) < 0) {
423                                 set_error("Can't stat key file");
424                                 goto out_err;
425                         }
426                         if(!S_ISREG(st.st_mode)) {
427                                 //                              set_error("Can't do exhausting read on non regular files");
428                                 // goto out_err;
429                                 fprintf(stderr,"Warning: exhausting read requested, but key file is not a regular file, function might never return.\n");
430                         }
431                 }
432                 buflen = 0;
433                 for(i = 0; read_horizon == 0 || i < read_horizon; i++) {
434                         if(i >= buflen - 1) {
435                                 buflen += 128;
436                                 pass = safe_realloc(pass, buflen);
437                                 if (!pass) {
438                                         set_error("Not enough memory while "
439                                                   "reading passphrase");
440                                         goto out_err;
441                                 }
442                         }
443                         if(read(fd, pass + i, 1) != 1 || (newline_stop && pass[i] == '\n'))
444                                 break;
445                 }
446                 if(key_file)
447                         close(fd);
448                 pass[i] = 0;
449                 *key = pass;
450                 *passLen = i;
451         }
452
453         return isatty(fd); /* Return true, when password reading can be tried on interactive fds */
454
455 out_err:
456         if(pass)
457                 safe_free(pass);
458         *key = NULL;
459         *passLen = 0;
460         return 0;
461 }
462