Error handling improvement thanks to Erik Edin.
[platform/upstream/cryptsetup.git] / lib / utils.c
1 #include <stdio.h>
2 #include <string.h>
3 #include <stdlib.h>
4 #include <stddef.h>
5 #include <stdarg.h>
6 #include <errno.h>
7 #include <linux/fs.h>
8 #include <sys/types.h>
9 #include <unistd.h>
10 #include <sys/types.h>
11 #include <sys/stat.h>
12 #include <sys/ioctl.h>
13 #include <fcntl.h>
14 #include <termios.h>
15
16 #include "libcryptsetup.h"
17 #include "internal.h"
18
19
20 struct safe_allocation {
21         size_t  size;
22         char    data[1];
23 };
24
25 static char *error=NULL;
26
27 void set_error_va(const char *fmt, va_list va)
28 {
29
30         if(error) {
31             free(error);
32             error=NULL;
33         }
34
35         vasprintf(&error, fmt, va);
36 }
37
38 void set_error(const char *fmt, ...)
39 {
40         va_list va;
41
42         va_start(va, fmt);
43         set_error_va(fmt, va);
44         va_end(va);
45 }
46
47 const char *get_error(void)
48 {
49         return error;
50 }
51
52 void *safe_alloc(size_t size)
53 {
54         struct safe_allocation *alloc;
55
56         if (!size)
57                 return NULL;
58
59         alloc = malloc(size + offsetof(struct safe_allocation, data));
60         if (!alloc)
61                 return NULL;
62
63         alloc->size = size;
64
65         return &alloc->data;
66 }
67
68 void safe_free(void *data)
69 {
70         struct safe_allocation *alloc;
71
72         if (!data)
73                 return;
74
75         alloc = data - offsetof(struct safe_allocation, data);
76
77         memset(data, 0, alloc->size);
78
79         alloc->size = 0x55aa55aa;
80         free(alloc);
81 }
82
83 void *safe_realloc(void *data, size_t size)
84 {
85         void *new_data;
86
87         new_data = safe_alloc(size);
88
89         if (new_data && data) {
90                 struct safe_allocation *alloc;
91
92                 alloc = data - offsetof(struct safe_allocation, data);
93
94                 if (size > alloc->size)
95                         size = alloc->size;
96
97                 memcpy(new_data, data, size);
98         }
99
100         safe_free(data);
101         return new_data;
102 }
103
104 char *safe_strdup(const char *s)
105 {
106         char *s2 = safe_alloc(strlen(s) + 1);
107
108         if (!s2)
109                 return NULL;
110
111         return strcpy(s2, s);
112 }
113
114 /* Credits go to Michal's padlock patches for this alignment code */
115
116 static void *aligned_malloc(char **base, int size, int alignment) 
117 {
118         char *ptr;
119
120         ptr  = malloc(size + alignment);
121         if(ptr == NULL) return NULL;
122
123         *base = ptr;
124         if(alignment > 1 && ((long)ptr & (alignment - 1))) {
125                 ptr += alignment - ((long)(ptr) & (alignment - 1));
126         }
127         return ptr;
128 }
129
130 static int sector_size(int fd) 
131 {
132         int bsize;
133         if (ioctl(fd,BLKSSZGET, &bsize) < 0)
134                 return -EINVAL;
135         else
136                 return bsize;
137 }
138
139 int sector_size_for_device(const char *device)
140 {
141         int fd = open(device, O_RDONLY);
142         int r;
143         if(fd < 0)
144                 return -EINVAL;
145         r = sector_size(fd);
146         close(fd);
147         return r;
148 }
149
150 ssize_t write_blockwise(int fd, const void *orig_buf, size_t count) 
151 {
152         char *padbuf; char *padbuf_base;
153         char *buf = (char *)orig_buf;
154         int r = 0;
155         int hangover; int solid; int bsize;
156
157         if ((bsize = sector_size(fd)) < 0)
158                 return bsize;
159
160         hangover = count % bsize;
161         solid = count - hangover;
162
163         padbuf = aligned_malloc(&padbuf_base, bsize, bsize);
164         if(padbuf == NULL) return -ENOMEM;
165
166         while(solid) {
167                 memcpy(padbuf, buf, bsize);
168                 r = write(fd, padbuf, bsize);
169                 if(r < 0 || r != bsize) goto out;
170
171                 solid -= bsize;
172                 buf += bsize;
173         }
174         if(hangover) {
175                 r = read(fd,padbuf,bsize);
176                 if(r < 0 || r != bsize) goto out;
177
178                 lseek(fd,-bsize,SEEK_CUR);
179                 memcpy(padbuf,buf,hangover);
180
181                 r = write(fd,padbuf, bsize);
182                 if(r < 0 || r != bsize) goto out;
183                 buf += hangover;
184         }
185  out:
186         free(padbuf_base);
187         return (buf-(char *)orig_buf)?(buf-(char *)orig_buf):r;
188
189 }
190
191 ssize_t read_blockwise(int fd, void *orig_buf, size_t count) {
192         char *padbuf; char *padbuf_base;
193         char *buf = (char *)orig_buf;
194         int r = 0;
195         int step;
196         int bsize;
197
198         if ((bsize = sector_size(fd)) < 0)
199                 return bsize;
200
201         padbuf = aligned_malloc(&padbuf_base, bsize, bsize);
202         if(padbuf == NULL) return -ENOMEM;
203
204         while(count) {
205                 r = read(fd,padbuf,bsize);
206                 if(r < 0 || r != bsize) {
207                         set_error("read failed in read_blockwise.\n");
208                         goto out;
209                 }
210                 step = count<bsize?count:bsize;
211                 memcpy(buf,padbuf,step);
212                 buf += step;
213                 count -= step;
214         }
215  out:
216         free(padbuf_base); 
217         return (buf-(char *)orig_buf)?(buf-(char *)orig_buf):r;
218 }
219
220 /* 
221  * Combines llseek with blockwise write. write_blockwise can already deal with short writes
222  * but we also need a function to deal with short writes at the start. But this information
223  * is implicitly included in the read/write offset, which can not be set to non-aligned 
224  * boundaries. Hence, we combine llseek with write.
225  */
226    
227 ssize_t write_lseek_blockwise(int fd, const char *buf, size_t count, off_t offset) {
228         int bsize = sector_size(fd);
229         const char *orig_buf = buf;
230         char frontPadBuf[bsize];
231         int frontHang = offset % bsize;
232         int r;
233
234         if (bsize < 0)
235                 return bsize;
236
237         lseek(fd, offset - frontHang, SEEK_SET);
238         if(offset % bsize) {
239                 int innerCount = count<bsize?count:bsize;
240
241                 r = read(fd,frontPadBuf,bsize);
242                 if(r < 0) return -1;
243
244                 memcpy(frontPadBuf+frontHang, buf, innerCount);
245
246                 lseek(fd, offset - frontHang, SEEK_SET);
247                 r = write(fd,frontPadBuf,bsize);
248                 if(r < 0) return -1;
249
250                 buf += innerCount;
251                 count -= innerCount;
252         }
253         if(count <= 0) return buf - orig_buf;
254
255         return write_blockwise(fd, buf, count) + innerCount;
256 }
257
258 /* Password reading helpers */
259
260 static int untimed_read(int fd, char *pass, size_t maxlen)
261 {
262         ssize_t i;
263
264         i = read(fd, pass, maxlen);
265         if (i > 0) {
266                 pass[i-1] = '\0';
267                 i = 0;
268         } else if (i == 0) { /* EOF */
269                 *pass = 0;
270                 i = -1;
271         }
272         return i;
273 }
274
275 static int timed_read(int fd, char *pass, size_t maxlen, long timeout)
276 {
277         struct timeval t;
278         fd_set fds;
279         int failed = -1;
280
281         FD_ZERO(&fds);
282         FD_SET(fd, &fds);
283         t.tv_sec = timeout;
284         t.tv_usec = 0;
285
286         if (select(fd+1, &fds, NULL, NULL, &t) > 0)
287                 failed = untimed_read(fd, pass, maxlen);
288         else
289                 set_error("Operation timed out");
290         return failed;
291 }
292
293 static int interactive_pass(const char *prompt, char *pass, size_t maxlen,
294                 long timeout)
295 {
296         struct termios orig, tmp;
297         int failed = -1;
298         int infd = STDIN_FILENO, outfd;
299
300         if (maxlen < 1)
301                 goto out_err;
302
303         /* Read and write to /dev/tty if available */
304         if ((infd = outfd = open("/dev/tty", O_RDWR)) == -1) {
305                 infd = STDIN_FILENO;
306                 outfd = STDERR_FILENO;
307         }
308
309         if (tcgetattr(infd, &orig)) {
310                 set_error("Unable to get terminal");
311                 goto out_err;
312         }
313         memcpy(&tmp, &orig, sizeof(tmp));
314         tmp.c_lflag &= ~ECHO;
315
316         write(outfd, prompt, strlen(prompt));
317         tcsetattr(infd, TCSAFLUSH, &tmp);
318         if (timeout)
319                 failed = timed_read(infd, pass, maxlen, timeout);
320         else
321                 failed = untimed_read(infd, pass, maxlen);
322         tcsetattr(infd, TCSAFLUSH, &orig);
323
324 out_err:
325         if (!failed)
326                 write(outfd, "\n", 1);
327         if (infd != STDIN_FILENO)
328                 close(infd);
329         return failed;
330 }
331
332 /*
333  * Password reading behaviour matrix of get_key
334  * 
335  *                    p   v   n   h
336  * -----------------+---+---+---+---
337  * interactive      | Y | Y | Y | Inf
338  * from fd          | N | N | Y | Inf
339  * from binary file | N | N | N | Inf or options->key_size
340  *
341  * Legend: p..prompt, v..can verify, n..newline-stop, h..read horizon
342  *
343  * Note: --key-file=- is interpreted as a read from a binary file (stdin)
344  *
345  * Returns true when more keys are available (that is when password
346  * reading can be retried as for interactive terminals).
347  */
348
349 int get_key(char *prompt, char **key, int *passLen, int key_size, const char *key_file, int passphrase_fd, int timeout, int how2verify)
350 {
351         int fd;
352         const int verify = how2verify & CRYPT_FLAG_VERIFY;
353         const int verify_if_possible = how2verify & CRYPT_FLAG_VERIFY_IF_POSSIBLE;
354         char *pass = NULL;
355         int newline_stop;
356         int read_horizon;
357
358         if(key_file && !strcmp(key_file, "-")) {
359                 /* Allow binary reading from stdin */
360                 fd = passphrase_fd;
361                 newline_stop = 0;
362                 read_horizon = 0;
363         } else if (key_file) {
364                 fd = open(key_file, O_RDONLY);
365                 if (fd < 0) {
366                         char buf[128];
367                         set_error("Error opening key file: %s",
368                                   strerror_r(errno, buf, 128));
369                         goto out_err;
370                 }
371                 newline_stop = 0;
372
373                 /* This can either be 0 (LUKS) or the actually number
374                  * of key bytes (default or passed by -s) */
375                 read_horizon = key_size;
376         } else {
377                 fd = passphrase_fd;
378                 newline_stop = 1;
379                 read_horizon = 0;   /* Infinite, if read from terminal or fd */
380         }       
381
382         /* Interactive case */
383         if(isatty(fd)) {
384                 int i;
385
386                 pass = safe_alloc(512);
387                 if (!pass || (i = interactive_pass(prompt, pass, 512, timeout))) {
388                         set_error("Error reading passphrase");
389                         goto out_err;
390                 }
391                 if (verify || verify_if_possible) {
392                         char pass_verify[512];
393                         i = interactive_pass("Verify passphrase: ", pass_verify, sizeof(pass_verify), timeout);
394                         if (i || strcmp(pass, pass_verify) != 0) {
395                                 set_error("Passphrases do not match");
396                                 goto out_err;
397                         }
398                         memset(pass_verify, 0, sizeof(pass_verify));
399                 }
400                 *passLen = strlen(pass);
401                 *key = pass;
402         } else {
403                 /* 
404                  * This is either a fd-input or a file, in neither case we can verify the input,
405                  * however we don't stop on new lines if it's a binary file.
406                  */
407                 int buflen, i;
408
409                 if(verify) {
410                         set_error("Can't do passphrase verification on non-tty inputs");
411                         goto out_err;
412                 }
413                 /* The following for control loop does an exhausting
414                  * read on the key material file, if requested with
415                  * key_size == 0, as it's done by LUKS. However, we
416                  * should warn the user, if it's a non-regular file,
417                  * such as /dev/random, because in this case, the loop
418                  * will read forever.
419                  */ 
420                 if(key_file && strcmp(key_file, "-") && read_horizon == 0) {
421                         struct stat st;
422                         if(stat(key_file, &st) < 0) {
423                                 set_error("Can't stat key file");
424                                 goto out_err;
425                         }
426                         if(!S_ISREG(st.st_mode)) {
427                                 //                              set_error("Can't do exhausting read on non regular files");
428                                 // goto out_err;
429                                 fprintf(stderr,"Warning: exhausting read requested, but key file is not a regular file, function might never return.\n");
430                         }
431                 }
432                 buflen = 0;
433                 for(i = 0; read_horizon == 0 || i < read_horizon; i++) {
434                         if(i >= buflen - 1) {
435                                 buflen += 128;
436                                 pass = safe_realloc(pass, buflen);
437                                 if (!pass) {
438                                         set_error("Not enough memory while "
439                                                   "reading passphrase");
440                                         goto out_err;
441                                 }
442                         }
443                         if(read(fd, pass + i, 1) != 1 || (newline_stop && pass[i] == '\n'))
444                                 break;
445                 }
446                 if(key_file)
447                         close(fd);
448                 pass[i] = 0;
449                 *key = pass;
450                 *passLen = i;
451         }
452
453         return isatty(fd); /* Return true, when password reading can be tried on interactive fds */
454
455 out_err:
456         if(pass)
457                 safe_free(pass);
458         *key = NULL;
459         *passLen = 0;
460         return 0;
461 }
462