Fix CVE-2017-6891 in minitasn1 code
[platform/upstream/gnutls.git] / src / psk.c
1 /*
2  * Copyright (C) 2005-2012 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuTLS.
5  *
6  * GnuTLS is free software: you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuTLS is distributed in the hope that it will be useful, but
12  * WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see
18  * <http://www.gnu.org/licenses/>.
19  */
20
21 #include <config.h>
22
23 /* Gnulib portability files. */
24
25 #ifndef ENABLE_PSK
26
27 #include <stdio.h>
28
29 int main(int argc, char **argv)
30 {
31         printf("\nPSK not supported. This program is a dummy.\n\n");
32         return 1;
33 };
34
35 #else
36
37 #include <stdio.h>
38 #include <string.h>
39 #include <stdlib.h>
40 #include <gnutls/gnutls.h>
41 #include <psktool-args.h>
42
43 #include <gnutls/crypto.h>      /* for random */
44
45 #include <sys/types.h>
46 #include <sys/stat.h>
47
48 #ifndef _WIN32
49 #include <pwd.h>
50 #include <unistd.h>
51 #else
52 #include <windows.h>
53 #endif
54
55 /* Gnulib portability files. */
56 #include <minmax.h>
57 #include "getpass.h"
58
59 static int write_key(const char *username, const char *key, int key_size,
60                      const char *passwd_file);
61
62 #define KPASSWD "/etc/passwd.psk"
63 #define MAX_KEY_SIZE 64
64 int main(int argc, char **argv)
65 {
66         int ret;
67 #ifndef _WIN32
68         struct passwd *pwd;
69 #endif
70         unsigned char key[MAX_KEY_SIZE];
71         char hex_key[MAX_KEY_SIZE * 2 + 1];
72         int optct, key_size;
73         gnutls_datum_t dkey;
74         const char *passwd, *username;
75         size_t hex_key_size = sizeof(hex_key);
76
77         if ((ret = gnutls_global_init()) < 0) {
78                 fprintf(stderr, "global_init: %s\n", gnutls_strerror(ret));
79                 exit(1);
80         }
81
82         umask(066);
83
84         optct = optionProcess(&psktoolOptions, argc, argv);
85         argc -= optct;
86         argv += optct;
87
88         if (!HAVE_OPT(PASSWD))
89                 passwd = (char *) KPASSWD;
90         else
91                 passwd = OPT_ARG(PASSWD);
92
93         if (!HAVE_OPT(USERNAME)) {
94 #ifndef _WIN32
95                 pwd = getpwuid(getuid());
96
97                 if (pwd == NULL) {
98                         fprintf(stderr, "No such user\n");
99                         return -1;
100                 }
101
102                 username = pwd->pw_name;
103 #else
104                 fprintf(stderr, "Please specify a user\n");
105                 return -1;
106 #endif
107         } else
108                 username = OPT_ARG(USERNAME);
109
110         if (HAVE_OPT(KEYSIZE) && OPT_VALUE_KEYSIZE > MAX_KEY_SIZE) {
111                 fprintf(stderr, "Key size is too long\n");
112                 exit(1);
113         }
114
115         if (!HAVE_OPT(KEYSIZE) || OPT_VALUE_KEYSIZE < 1)
116                 key_size = 16;
117         else
118                 key_size = OPT_VALUE_KEYSIZE;
119
120         printf("Generating a random key for user '%s'\n", username);
121
122         ret = gnutls_rnd(GNUTLS_RND_RANDOM, (char *) key, key_size);
123         if (ret < 0) {
124                 fprintf(stderr, "Not enough randomness\n");
125                 exit(1);
126         }
127
128         dkey.data = key;
129         dkey.size = key_size;
130
131         ret = gnutls_hex_encode(&dkey, hex_key, &hex_key_size);
132         if (ret < 0) {
133                 fprintf(stderr, "HEX encoding error\n");
134                 exit(1);
135         }
136
137         ret = write_key(username, hex_key, hex_key_size, passwd);
138         if (ret == 0)
139                 printf("Key stored to %s\n", passwd);
140
141         return ret;
142 }
143
144 static int filecopy(const char *src, const char *dst)
145 {
146         FILE *fd, *fd2;
147         char line[5 * 1024];
148         char *p;
149
150         fd = fopen(dst, "w");
151         if (fd == NULL) {
152                 fprintf(stderr, "Cannot open '%s' for write\n", dst);
153                 return -1;
154         }
155
156         fd2 = fopen(src, "r");
157         if (fd2 == NULL) {
158                 /* empty file */
159                 fclose(fd);
160                 return 0;
161         }
162
163         line[sizeof(line) - 1] = 0;
164         do {
165                 p = fgets(line, sizeof(line) - 1, fd2);
166                 if (p == NULL)
167                         break;
168
169                 fputs(line, fd);
170         }
171         while (1);
172
173         fclose(fd);
174         fclose(fd2);
175
176         return 0;
177 }
178
179 static int
180 write_key(const char *username, const char *key, int key_size,
181           const char *passwd_file)
182 {
183         FILE *fd;
184         char line[5 * 1024];
185         char *p, *pp;
186         char tmpname[1024];
187
188
189         /* delete previous entry */
190         struct stat st;
191         FILE *fd2;
192         int put;
193
194         if (strlen(passwd_file) + 5 > sizeof(tmpname)) {
195                 fprintf(stderr, "file '%s' is tooooo long\n", passwd_file);
196                 return -1;
197         }
198
199         snprintf(tmpname, sizeof(tmpname), "%s.tmp", passwd_file);
200
201         if (stat(tmpname, &st) != -1) {
202                 fprintf(stderr, "file '%s' is locked\n", tmpname);
203                 return -1;
204         }
205
206         if (filecopy(passwd_file, tmpname) != 0) {
207                 fprintf(stderr, "Cannot copy '%s' to '%s'\n", passwd_file,
208                         tmpname);
209                 return -1;
210         }
211
212         fd = fopen(passwd_file, "w");
213         if (fd == NULL) {
214                 fprintf(stderr, "Cannot open '%s' for write\n",
215                         passwd_file);
216                 remove(tmpname);
217                 return -1;
218         }
219
220         fd2 = fopen(tmpname, "r");
221         if (fd2 == NULL) {
222                 fprintf(stderr, "Cannot open '%s' for read\n", tmpname);
223                 remove(tmpname);
224                 return -1;
225         }
226
227         put = 0;
228         do {
229                 p = fgets(line, sizeof(line) - 1, fd2);
230                 if (p == NULL)
231                         break;
232
233                 pp = strchr(line, ':');
234                 if (pp == NULL)
235                         continue;
236
237                 if (strncmp(p, username,
238                             MAX(strlen(username),
239                                 (unsigned int) (pp - p))) == 0) {
240                         put = 1;
241                         fprintf(fd, "%s:%s\n", username, key);
242                 } else {
243                         fputs(line, fd);
244                 }
245         }
246         while (1);
247
248         if (put == 0) {
249                 fprintf(fd, "%s:%s\n", username, key);
250         }
251
252         fclose(fd);
253         fclose(fd2);
254
255         remove(tmpname);
256
257
258         return 0;
259 }
260
261 #endif                          /* ENABLE_PSK */