[REFACTOR] move and rename /un/register_us_page_probe()
[kernel/swap-modules.git] / driver / helper.c
1 #include <dbi_kprobes.h>
2 #include <dbi_kprobes_deps.h>
3 #include <ksyms.h>
4 #include "us_proc_inst.h"
5 #include "us_slot_manager.h"
6 #include "storage.h"
7 #include "sspt/sspt.h"
8
9 /*
10  ******************************************************************************
11  *                               do_page_fault()                              *
12  ******************************************************************************
13  */
14
/* Per-instance data of the do_page_fault() kretprobe. */
struct pf_data {
	unsigned long addr;	/* faulting address, saved by entry_handler_pf() */
};
18
19 static int entry_handler_pf(struct kretprobe_instance *ri, struct pt_regs *regs)
20 {
21         struct pf_data *data = (struct pf_data *)ri->data;
22
23 #ifdef CONFIG_X86
24         data->addr = read_cr2();
25 #elif CONFIG_ARM
26         data->addr = regs->ARM_r0;
27 #else
28 #error this architecture is not supported
29 #endif
30
31         return 0;
32 }
33
34 /* Detects when IPs are really loaded into phy mem and installs probes. */
35 static int ret_handler_pf(struct kretprobe_instance *ri, struct pt_regs *regs)
36 {
37         struct task_struct *task = current->group_leader;
38         struct mm_struct *mm = task->mm;
39         struct sspt_procs *procs = NULL;
40         /*
41          * Because process threads have same address space
42          * we instrument only group_leader of all this threads
43          */
44         struct pf_data *data;
45         unsigned long addr = 0;
46         int valid_addr;
47
48         if (task->flags & PF_KTHREAD) {
49                 goto out;
50         }
51
52         if (!is_us_instrumentation()) {
53                 goto out;
54         }
55
56         data = (struct pf_data *)ri->data;
57         addr = data->addr;
58
59         valid_addr = mm && page_present(mm, addr);
60         if (!valid_addr) {
61                 goto out;
62         }
63
64         if (is_libonly()) {
65                 procs = sspt_procs_get_by_task_or_new(task);
66         } else {
67                 // find task
68                 if (us_proc_info.tgid == 0) {
69                         pid_t tgid = find_proc_by_task(task, us_proc_info.m_f_dentry);
70                         if (tgid) {
71                                 us_proc_info.tgid = gl_nNotifyTgid = tgid;
72                                 procs = sspt_procs_get_by_task_or_new(task);
73
74                                 /* install probes in already mapped memory */
75                                 install_proc_probes(task, procs);
76                         }
77                 }
78
79                 if (us_proc_info.tgid == task->tgid) {
80                         procs = sspt_procs_get_by_task_or_new(task);
81                 }
82         }
83
84         if (procs) {
85                 unsigned long page = addr & PAGE_MASK;
86                 install_page_probes(page, task, procs);
87         }
88
89 out:
90         return 0;
91 }
92
93 static struct kretprobe pf_kretprobe = {
94         .entry_handler = entry_handler_pf,
95         .handler = ret_handler_pf,
96         .data_size = sizeof(struct pf_data)
97 };
98
99
100
101 /*
102  ******************************************************************************
103  *                              copy_process()                                *
104  ******************************************************************************
105  */
106
107 static void recover_child(struct task_struct *child_task, struct sspt_procs *procs)
108 {
109         uninstall_us_proc_probes(child_task, procs, US_DISARM);
110         dbi_disarm_urp_inst_for_task(current, child_task);
111 }
112
113 static void rm_uprobes_child(struct task_struct *task)
114 {
115         struct sspt_procs *procs = sspt_procs_get_by_task(current);
116         if(procs) {
117                 recover_child(task, procs);
118         }
119 }
120
121 /* Delete uprobs in children at fork */
122 static int ret_handler_cp(struct kretprobe_instance *ri, struct pt_regs *regs)
123 {
124         struct task_struct* task = (struct task_struct *)regs_return_value(regs);
125
126         if(!task || IS_ERR(task))
127                 goto out;
128
129         if(task->mm != current->mm)     /* check flags CLONE_VM */
130                 rm_uprobes_child(task);
131
132 out:
133         return 0;
134 }
135
136 static struct kretprobe cp_kretprobe = {
137         .handler = ret_handler_cp,
138 };
139
140
141
142 /*
143  ******************************************************************************
144  *                                mm_release()                                *
145  ******************************************************************************
146  */
147
148 /* Detects when target process removes IPs. */
149 static int mr_pre_handler(struct kprobe *p, struct pt_regs *regs)
150 {
151         struct sspt_procs *procs = NULL;
152         struct task_struct *task = (struct task_struct *)regs->ARM_r0; /* for ARM */
153
154         if (!is_us_instrumentation() || task->tgid != task->pid) {
155                 goto out;
156         }
157
158         if (is_libonly()) {
159                 procs = sspt_procs_get_by_task(task);
160         } else {
161                 if (task->tgid == us_proc_info.tgid) {
162                         procs = sspt_procs_get_by_task(task);
163                         us_proc_info.tgid = 0;
164                 }
165         }
166
167         if (procs) {
168                 int ret = uninstall_us_proc_probes(task, procs, US_UNREGS_PROBE);
169                 if (ret != 0) {
170                         printk("failed to uninstall IPs (%d)!\n", ret);
171                 }
172
173                 dbi_unregister_all_uprobes(task);
174         }
175
176 out:
177         return 0;
178 }
179
180 static struct kprobe mr_kprobe = {
181         .pre_handler = mr_pre_handler
182 };
183
184
185
186 /*
187  ******************************************************************************
188  *                                 do_munmap()                                *
189  ******************************************************************************
190  */
191
192 static int remove_unmap_probes(struct task_struct *task, struct sspt_procs *procs, unsigned long start, size_t len)
193 {
194         struct mm_struct *mm = task->mm;
195         struct vm_area_struct *vma;
196
197         if ((start & ~PAGE_MASK) || start > TASK_SIZE || len > TASK_SIZE - start) {
198                 return -EINVAL;
199         }
200
201         if ((len = PAGE_ALIGN(len)) == 0) {
202                 return -EINVAL;
203         }
204
205         vma = find_vma(mm, start);
206         if (vma && check_vma(vma)) {
207                 struct sspt_file *file;
208                 unsigned long end = start + len;
209                 struct dentry *dentry = vma->vm_file->f_dentry;
210
211                 file = sspt_procs_find_file(procs, dentry);
212                 if (file) {
213                         if (vma->vm_start == start || vma->vm_end == end) {
214                                 unregister_us_file_probes(task, file, US_UNREGS_PROBE);
215                                 file->loaded = 0;
216                         } else {
217                                 unsigned long page_addr;
218                                 struct sspt_page *page;
219
220                                 for (page_addr = vma->vm_start; page_addr < vma->vm_end; page_addr += PAGE_SIZE) {
221                                         page = sspt_find_page_mapped(file, page_addr);
222                                         if (page) {
223                                                 sspt_unregister_page(page, US_UNREGS_PROBE, task);
224                                         }
225                                 }
226
227                                 if (check_install_pages_in_file(task, file)) {
228                                         file->loaded = 0;
229                                 }
230                         }
231                 }
232         }
233
234         return 0;
235 }
236
237 /* Detects when target removes IPs. */
238 static int unmap_pre_handler(struct kprobe *p, struct pt_regs *regs)
239 {
240         /* for ARM */
241         struct mm_struct *mm = (struct mm_struct *)regs->ARM_r0;
242         unsigned long start = regs->ARM_r1;
243         size_t len = (size_t)regs->ARM_r2;
244
245         struct sspt_procs *procs = NULL;
246         struct task_struct *task = current;
247
248         //if user-space instrumentation is not set
249         if (!is_us_instrumentation()) {
250                 goto out;
251         }
252
253         procs = sspt_procs_get_by_task(task);
254         if (procs) {
255                 if (remove_unmap_probes(task, procs, start, len)) {
256                         printk("ERROR do_munmap: start=%lx, len=%x\n", start, len);
257                 }
258         }
259
260 out:
261         return 0;
262 }
263
264 static struct kprobe unmap_kprobe = {
265         .pre_handler = unmap_pre_handler
266 };
267
268
269
270 int register_helper(void)
271 {
272         int ret = 0;
273
274         /* install kprobe on 'do_munmap' to detect when for remove user space probes */
275         ret = dbi_register_kprobe(&unmap_kprobe);
276         if (ret) {
277                 printk("dbi_register_kprobe(do_munmap) result=%d!\n", ret);
278                 return ret;
279         }
280
281         /* install kprobe on 'mm_release' to detect when for remove user space probes */
282         ret = dbi_register_kprobe(&mr_kprobe);
283         if (ret != 0) {
284                 printk("dbi_register_kprobe(mm_release) result=%d!\n", ret);
285                 goto unregister_unmap;
286         }
287
288
289         /* install kretprobe on 'copy_process' */
290         ret = dbi_register_kretprobe(&cp_kretprobe);
291         if (ret) {
292                 printk("dbi_register_kretprobe(copy_process) result=%d!\n", ret);
293                 goto unregister_mr;
294         }
295
296         /* install kretprobe on 'do_page_fault' to detect when they will be loaded */
297         ret = dbi_register_kretprobe(&pf_kretprobe);
298         if (ret) {
299                 printk("dbi_register_kretprobe(do_page_fault) result=%d!\n", ret);
300                 goto unregister_cp;
301         }
302
303         return ret;
304
305 unregister_cp:
306         dbi_unregister_kretprobe(&cp_kretprobe);
307
308 unregister_mr:
309         dbi_unregister_kprobe(&mr_kprobe, NULL);
310
311 unregister_unmap:
312         dbi_unregister_kprobe(&unmap_kprobe, NULL);
313
314         return ret;
315 }
316
317 void unregister_helper(void)
318 {
319         /* uninstall kretprobe with 'do_page_fault' */
320         dbi_unregister_kretprobe(&pf_kretprobe);
321
322         /* uninstall kretprobe with 'copy_process' */
323         dbi_unregister_kretprobe(&cp_kretprobe);
324
325         /* uninstall kprobe with 'mm_release' */
326         dbi_unregister_kprobe(&mr_kprobe, NULL);
327
328         /* uninstall kprobe with 'do_munmap' */
329         dbi_unregister_kprobe(&unmap_kprobe, NULL);
330 }
331
332 int init_helper(void)
333 {
334         unsigned long addr;
335         addr = swap_ksyms("do_page_fault");
336         if (addr == 0) {
337                 printk("Cannot find address for page fault function!\n");
338                 return -EINVAL;
339         }
340         pf_kretprobe.kp.addr = (kprobe_opcode_t *)addr;
341
342         addr = swap_ksyms("copy_process");
343         if (addr == 0) {
344                 printk("Cannot find address for copy_process function!\n");
345                 return -EINVAL;
346         }
347         cp_kretprobe.kp.addr = (kprobe_opcode_t *)addr;
348
349         addr = swap_ksyms("mm_release");
350         if (addr == 0) {
351                 printk("Cannot find address for mm_release function!\n");
352                 return -EINVAL;
353         }
354         mr_kprobe.addr = (kprobe_opcode_t *)addr;
355
356         addr = swap_ksyms("do_munmap");
357         if (addr == 0) {
358                 printk("Cannot find address for do_munmap function!\n");
359                 return -EINVAL;
360         }
361         unmap_kprobe.addr = (kprobe_opcode_t *)addr;
362
363         return 0;
364 }
365
/* Counterpart of init_helper(), which only stores addresses: nothing to undo. */
void uninit_helper(void)
{
	/* intentionally empty */
}