Merge tag 'asoc-fix-v5.15-rc5' of https://git.kernel.org/pub/scm/linux/kernel/git...
[platform/kernel/linux-starfive.git] / mm / maccess.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Access kernel or user memory without faulting.
4  */
5 #include <linux/export.h>
6 #include <linux/mm.h>
7 #include <linux/uaccess.h>
8
9 bool __weak copy_from_kernel_nofault_allowed(const void *unsafe_src,
10                 size_t size)
11 {
12         return true;
13 }
14
15 #ifdef HAVE_GET_KERNEL_NOFAULT
16
17 #define copy_from_kernel_nofault_loop(dst, src, len, type, err_label)   \
18         while (len >= sizeof(type)) {                                   \
19                 __get_kernel_nofault(dst, src, type, err_label);                \
20                 dst += sizeof(type);                                    \
21                 src += sizeof(type);                                    \
22                 len -= sizeof(type);                                    \
23         }
24
25 long copy_from_kernel_nofault(void *dst, const void *src, size_t size)
26 {
27         unsigned long align = 0;
28
29         if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
30                 align = (unsigned long)dst | (unsigned long)src;
31
32         if (!copy_from_kernel_nofault_allowed(src, size))
33                 return -ERANGE;
34
35         pagefault_disable();
36         if (!(align & 7))
37                 copy_from_kernel_nofault_loop(dst, src, size, u64, Efault);
38         if (!(align & 3))
39                 copy_from_kernel_nofault_loop(dst, src, size, u32, Efault);
40         if (!(align & 1))
41                 copy_from_kernel_nofault_loop(dst, src, size, u16, Efault);
42         copy_from_kernel_nofault_loop(dst, src, size, u8, Efault);
43         pagefault_enable();
44         return 0;
45 Efault:
46         pagefault_enable();
47         return -EFAULT;
48 }
49 EXPORT_SYMBOL_GPL(copy_from_kernel_nofault);
50
51 #define copy_to_kernel_nofault_loop(dst, src, len, type, err_label)     \
52         while (len >= sizeof(type)) {                                   \
53                 __put_kernel_nofault(dst, src, type, err_label);                \
54                 dst += sizeof(type);                                    \
55                 src += sizeof(type);                                    \
56                 len -= sizeof(type);                                    \
57         }
58
59 long copy_to_kernel_nofault(void *dst, const void *src, size_t size)
60 {
61         unsigned long align = 0;
62
63         if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
64                 align = (unsigned long)dst | (unsigned long)src;
65
66         pagefault_disable();
67         if (!(align & 7))
68                 copy_to_kernel_nofault_loop(dst, src, size, u64, Efault);
69         if (!(align & 3))
70                 copy_to_kernel_nofault_loop(dst, src, size, u32, Efault);
71         if (!(align & 1))
72                 copy_to_kernel_nofault_loop(dst, src, size, u16, Efault);
73         copy_to_kernel_nofault_loop(dst, src, size, u8, Efault);
74         pagefault_enable();
75         return 0;
76 Efault:
77         pagefault_enable();
78         return -EFAULT;
79 }
80
81 long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count)
82 {
83         const void *src = unsafe_addr;
84
85         if (unlikely(count <= 0))
86                 return 0;
87         if (!copy_from_kernel_nofault_allowed(unsafe_addr, count))
88                 return -ERANGE;
89
90         pagefault_disable();
91         do {
92                 __get_kernel_nofault(dst, src, u8, Efault);
93                 dst++;
94                 src++;
95         } while (dst[-1] && src - unsafe_addr < count);
96         pagefault_enable();
97
98         dst[-1] = '\0';
99         return src - unsafe_addr;
100 Efault:
101         pagefault_enable();
102         dst[-1] = '\0';
103         return -EFAULT;
104 }
105 #else /* HAVE_GET_KERNEL_NOFAULT */
106 /**
107  * copy_from_kernel_nofault(): safely attempt to read from kernel-space
108  * @dst: pointer to the buffer that shall take the data
109  * @src: address to read from
110  * @size: size of the data chunk
111  *
112  * Safely read from kernel address @src to the buffer at @dst.  If a kernel
113  * fault happens, handle that and return -EFAULT.  If @src is not a valid kernel
114  * address, return -ERANGE.
115  *
116  * We ensure that the copy_from_user is executed in atomic context so that
117  * do_page_fault() doesn't attempt to take mmap_lock.  This makes
118  * copy_from_kernel_nofault() suitable for use within regions where the caller
119  * already holds mmap_lock, or other locks which nest inside mmap_lock.
120  */
121 long copy_from_kernel_nofault(void *dst, const void *src, size_t size)
122 {
123         long ret;
124         mm_segment_t old_fs = get_fs();
125
126         if (!copy_from_kernel_nofault_allowed(src, size))
127                 return -ERANGE;
128
129         set_fs(KERNEL_DS);
130         pagefault_disable();
131         ret = __copy_from_user_inatomic(dst, (__force const void __user *)src,
132                         size);
133         pagefault_enable();
134         set_fs(old_fs);
135
136         if (ret)
137                 return -EFAULT;
138         return 0;
139 }
140 EXPORT_SYMBOL_GPL(copy_from_kernel_nofault);
141
142 /**
143  * copy_to_kernel_nofault(): safely attempt to write to a location
144  * @dst: address to write to
145  * @src: pointer to the data that shall be written
146  * @size: size of the data chunk
147  *
148  * Safely write to address @dst from the buffer at @src.  If a kernel fault
149  * happens, handle that and return -EFAULT.
150  */
151 long copy_to_kernel_nofault(void *dst, const void *src, size_t size)
152 {
153         long ret;
154         mm_segment_t old_fs = get_fs();
155
156         set_fs(KERNEL_DS);
157         pagefault_disable();
158         ret = __copy_to_user_inatomic((__force void __user *)dst, src, size);
159         pagefault_enable();
160         set_fs(old_fs);
161
162         if (ret)
163                 return -EFAULT;
164         return 0;
165 }
166
167 /**
168  * strncpy_from_kernel_nofault: - Copy a NUL terminated string from unsafe
169  *                               address.
170  * @dst:   Destination address, in kernel space.  This buffer must be at
171  *         least @count bytes long.
172  * @unsafe_addr: Unsafe address.
173  * @count: Maximum number of bytes to copy, including the trailing NUL.
174  *
175  * Copies a NUL-terminated string from unsafe address to kernel buffer.
176  *
177  * On success, returns the length of the string INCLUDING the trailing NUL.
178  *
179  * If access fails, returns -EFAULT (some data may have been copied and the
180  * trailing NUL added).  If @unsafe_addr is not a valid kernel address, return
181  * -ERANGE.
182  *
183  * If @count is smaller than the length of the string, copies @count-1 bytes,
184  * sets the last byte of @dst buffer to NUL and returns @count.
185  */
186 long strncpy_from_kernel_nofault(char *dst, const void *unsafe_addr, long count)
187 {
188         mm_segment_t old_fs = get_fs();
189         const void *src = unsafe_addr;
190         long ret;
191
192         if (unlikely(count <= 0))
193                 return 0;
194         if (!copy_from_kernel_nofault_allowed(unsafe_addr, count))
195                 return -ERANGE;
196
197         set_fs(KERNEL_DS);
198         pagefault_disable();
199
200         do {
201                 ret = __get_user(*dst++, (const char __user __force *)src++);
202         } while (dst[-1] && ret == 0 && src - unsafe_addr < count);
203
204         dst[-1] = '\0';
205         pagefault_enable();
206         set_fs(old_fs);
207
208         return ret ? -EFAULT : src - unsafe_addr;
209 }
210 #endif /* HAVE_GET_KERNEL_NOFAULT */
211
212 /**
213  * copy_from_user_nofault(): safely attempt to read from a user-space location
214  * @dst: pointer to the buffer that shall take the data
215  * @src: address to read from. This must be a user address.
216  * @size: size of the data chunk
217  *
218  * Safely read from user address @src to the buffer at @dst. If a kernel fault
219  * happens, handle that and return -EFAULT.
220  */
221 long copy_from_user_nofault(void *dst, const void __user *src, size_t size)
222 {
223         long ret = -EFAULT;
224         mm_segment_t old_fs = force_uaccess_begin();
225
226         if (access_ok(src, size)) {
227                 pagefault_disable();
228                 ret = __copy_from_user_inatomic(dst, src, size);
229                 pagefault_enable();
230         }
231         force_uaccess_end(old_fs);
232
233         if (ret)
234                 return -EFAULT;
235         return 0;
236 }
237 EXPORT_SYMBOL_GPL(copy_from_user_nofault);
238
239 /**
240  * copy_to_user_nofault(): safely attempt to write to a user-space location
241  * @dst: address to write to
242  * @src: pointer to the data that shall be written
243  * @size: size of the data chunk
244  *
245  * Safely write to address @dst from the buffer at @src.  If a kernel fault
246  * happens, handle that and return -EFAULT.
247  */
248 long copy_to_user_nofault(void __user *dst, const void *src, size_t size)
249 {
250         long ret = -EFAULT;
251         mm_segment_t old_fs = force_uaccess_begin();
252
253         if (access_ok(dst, size)) {
254                 pagefault_disable();
255                 ret = __copy_to_user_inatomic(dst, src, size);
256                 pagefault_enable();
257         }
258         force_uaccess_end(old_fs);
259
260         if (ret)
261                 return -EFAULT;
262         return 0;
263 }
264 EXPORT_SYMBOL_GPL(copy_to_user_nofault);
265
266 /**
267  * strncpy_from_user_nofault: - Copy a NUL terminated string from unsafe user
268  *                              address.
269  * @dst:   Destination address, in kernel space.  This buffer must be at
270  *         least @count bytes long.
271  * @unsafe_addr: Unsafe user address.
272  * @count: Maximum number of bytes to copy, including the trailing NUL.
273  *
274  * Copies a NUL-terminated string from unsafe user address to kernel buffer.
275  *
276  * On success, returns the length of the string INCLUDING the trailing NUL.
277  *
278  * If access fails, returns -EFAULT (some data may have been copied
279  * and the trailing NUL added).
280  *
281  * If @count is smaller than the length of the string, copies @count-1 bytes,
282  * sets the last byte of @dst buffer to NUL and returns @count.
283  */
284 long strncpy_from_user_nofault(char *dst, const void __user *unsafe_addr,
285                               long count)
286 {
287         mm_segment_t old_fs;
288         long ret;
289
290         if (unlikely(count <= 0))
291                 return 0;
292
293         old_fs = force_uaccess_begin();
294         pagefault_disable();
295         ret = strncpy_from_user(dst, unsafe_addr, count);
296         pagefault_enable();
297         force_uaccess_end(old_fs);
298
299         if (ret >= count) {
300                 ret = count;
301                 dst[ret - 1] = '\0';
302         } else if (ret > 0) {
303                 ret++;
304         }
305
306         return ret;
307 }
308
309 /**
310  * strnlen_user_nofault: - Get the size of a user string INCLUDING final NUL.
311  * @unsafe_addr: The string to measure.
312  * @count: Maximum count (including NUL)
313  *
314  * Get the size of a NUL-terminated string in user space without pagefault.
315  *
316  * Returns the size of the string INCLUDING the terminating NUL.
317  *
318  * If the string is too long, returns a number larger than @count. User
319  * has to check the return value against "> count".
320  * On exception (or invalid count), returns 0.
321  *
322  * Unlike strnlen_user, this can be used from IRQ handler etc. because
323  * it disables pagefaults.
324  */
325 long strnlen_user_nofault(const void __user *unsafe_addr, long count)
326 {
327         mm_segment_t old_fs;
328         int ret;
329
330         old_fs = force_uaccess_begin();
331         pagefault_disable();
332         ret = strnlen_user(unsafe_addr, count);
333         pagefault_enable();
334         force_uaccess_end(old_fs);
335
336         return ret;
337 }