iommu/amd: Use pr_fmt()
[platform/kernel/linux-rpi.git] / drivers / iommu / amd_iommu_v2.c
1 /*
2  * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
3  * Author: Joerg Roedel <jroedel@suse.de>
4  *
5  * This program is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 as published
7  * by the Free Software Foundation.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
17  */
18
19 #define pr_fmt(fmt)     "AMD-Vi: " fmt
20
21 #include <linux/mmu_notifier.h>
22 #include <linux/amd-iommu.h>
23 #include <linux/mm_types.h>
24 #include <linux/profile.h>
25 #include <linux/module.h>
26 #include <linux/sched.h>
27 #include <linux/sched/mm.h>
28 #include <linux/iommu.h>
29 #include <linux/wait.h>
30 #include <linux/pci.h>
31 #include <linux/gfp.h>
32
33 #include "amd_iommu_types.h"
34 #include "amd_iommu_proto.h"
35
36 MODULE_LICENSE("GPL v2");
37 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");
38
39 #define MAX_DEVICES             0x10000
40 #define PRI_QUEUE_SIZE          512
41
42 struct pri_queue {
43         atomic_t inflight;
44         bool finish;
45         int status;
46 };
47
48 struct pasid_state {
49         struct list_head list;                  /* For global state-list */
50         atomic_t count;                         /* Reference count */
51         unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
52                                                    calls */
53         struct mm_struct *mm;                   /* mm_struct for the faults */
54         struct mmu_notifier mn;                 /* mmu_notifier handle */
55         struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
56         struct device_state *device_state;      /* Link to our device_state */
57         int pasid;                              /* PASID index */
58         bool invalid;                           /* Used during setup and
59                                                    teardown of the pasid */
60         spinlock_t lock;                        /* Protect pri_queues and
61                                                    mmu_notifier_count */
62         wait_queue_head_t wq;                   /* To wait for count == 0 */
63 };
64
65 struct device_state {
66         struct list_head list;
67         u16 devid;
68         atomic_t count;
69         struct pci_dev *pdev;
70         struct pasid_state **states;
71         struct iommu_domain *domain;
72         int pasid_levels;
73         int max_pasids;
74         amd_iommu_invalid_ppr_cb inv_ppr_cb;
75         amd_iommu_invalidate_ctx inv_ctx_cb;
76         spinlock_t lock;
77         wait_queue_head_t wq;
78 };
79
80 struct fault {
81         struct work_struct work;
82         struct device_state *dev_state;
83         struct pasid_state *state;
84         struct mm_struct *mm;
85         u64 address;
86         u16 devid;
87         u16 pasid;
88         u16 tag;
89         u16 finish;
90         u16 flags;
91 };
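
/*
 * Illustrative note (not part of the original source): there is one
 * device_state per PCI device registered through amd_iommu_init_device(),
 * one pasid_state per (device, PASID) binding created by
 * amd_iommu_bind_pasid(), and one struct fault allocated for every PPR
 * request that gets queued to the fault workqueue.
 */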
92
93 static LIST_HEAD(state_list);
94 static spinlock_t state_lock;
95
96 static struct workqueue_struct *iommu_wq;
97
98 static void free_pasid_states(struct device_state *dev_state);
99
100 static u16 device_id(struct pci_dev *pdev)
101 {
102         u16 devid;
103
104         devid = pdev->bus->number;
105         devid = (devid << 8) | pdev->devfn;
106
107         return devid;
108 }
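
/*
 * Illustrative note (not part of the original source): device_id() packs
 * the PCI bus number and devfn into the 16-bit requester ID used by the
 * IOMMU, i.e. devid = (bus << 8) | devfn.  For example, device 00:14.0 has
 * bus = 0x00 and devfn = (0x14 << 3) | 0 = 0xa0, so devid = 0x00a0.
 * ppr_notifier() below reverses this with PCI_BUS_NUM(devid) and
 * devid & 0xff.
 */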
109
110 static struct device_state *__get_device_state(u16 devid)
111 {
112         struct device_state *dev_state;
113
114         list_for_each_entry(dev_state, &state_list, list) {
115                 if (dev_state->devid == devid)
116                         return dev_state;
117         }
118
119         return NULL;
120 }
121
122 static struct device_state *get_device_state(u16 devid)
123 {
124         struct device_state *dev_state;
125         unsigned long flags;
126
127         spin_lock_irqsave(&state_lock, flags);
128         dev_state = __get_device_state(devid);
129         if (dev_state != NULL)
130                 atomic_inc(&dev_state->count);
131         spin_unlock_irqrestore(&state_lock, flags);
132
133         return dev_state;
134 }
135
136 static void free_device_state(struct device_state *dev_state)
137 {
138         struct iommu_group *group;
139
140         /*
141          * First detach device from domain - No more PRI requests will arrive
142          * from that device after it is unbound from the IOMMUv2 domain.
143          */
144         group = iommu_group_get(&dev_state->pdev->dev);
145         if (WARN_ON(!group))
146                 return;
147
148         iommu_detach_group(dev_state->domain, group);
149
150         iommu_group_put(group);
151
152         /* Everything is down now, free the IOMMUv2 domain */
153         iommu_domain_free(dev_state->domain);
154
155         /* Finally get rid of the device-state */
156         kfree(dev_state);
157 }
158
159 static void put_device_state(struct device_state *dev_state)
160 {
161         if (atomic_dec_and_test(&dev_state->count))
162                 wake_up(&dev_state->wq);
163 }
164
165 /* Must be called under dev_state->lock */
166 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
167                                                   int pasid, bool alloc)
168 {
169         struct pasid_state **root, **ptr;
170         int level, index;
171
172         level = dev_state->pasid_levels;
173         root  = dev_state->states;
174
175         while (true) {
176
177                 index = (pasid >> (9 * level)) & 0x1ff;
178                 ptr   = &root[index];
179
180                 if (level == 0)
181                         break;
182
183                 if (*ptr == NULL) {
184                         if (!alloc)
185                                 return NULL;
186
187                         *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
188                         if (*ptr == NULL)
189                                 return NULL;
190                 }
191
192                 root   = (struct pasid_state **)*ptr;
193                 level -= 1;
194         }
195
196         return ptr;
197 }
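
/*
 * Illustrative note (not part of the original source): the PASID state
 * table is a small radix tree with 512 (2^9) pointers per level, one page
 * per node.  Each level consumes 9 bits of the PASID, most significant
 * bits first.  With pasid_levels == 1 and pasid == 0x2a5, for example, the
 * walk uses index (0x2a5 >> 9) & 0x1ff == 1 in the top-level page and
 * 0x2a5 & 0x1ff == 0xa5 in the leaf page holding the pasid_state pointer.
 */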
198
199 static int set_pasid_state(struct device_state *dev_state,
200                            struct pasid_state *pasid_state,
201                            int pasid)
202 {
203         struct pasid_state **ptr;
204         unsigned long flags;
205         int ret;
206
207         spin_lock_irqsave(&dev_state->lock, flags);
208         ptr = __get_pasid_state_ptr(dev_state, pasid, true);
209
210         ret = -ENOMEM;
211         if (ptr == NULL)
212                 goto out_unlock;
213
214         ret = -ENOMEM;
215         if (*ptr != NULL)
216                 goto out_unlock;
217
218         *ptr = pasid_state;
219
220         ret = 0;
221
222 out_unlock:
223         spin_unlock_irqrestore(&dev_state->lock, flags);
224
225         return ret;
226 }
227
228 static void clear_pasid_state(struct device_state *dev_state, int pasid)
229 {
230         struct pasid_state **ptr;
231         unsigned long flags;
232
233         spin_lock_irqsave(&dev_state->lock, flags);
234         ptr = __get_pasid_state_ptr(dev_state, pasid, true);
235
236         if (ptr == NULL)
237                 goto out_unlock;
238
239         *ptr = NULL;
240
241 out_unlock:
242         spin_unlock_irqrestore(&dev_state->lock, flags);
243 }
244
245 static struct pasid_state *get_pasid_state(struct device_state *dev_state,
246                                            int pasid)
247 {
248         struct pasid_state **ptr, *ret = NULL;
249         unsigned long flags;
250
251         spin_lock_irqsave(&dev_state->lock, flags);
252         ptr = __get_pasid_state_ptr(dev_state, pasid, false);
253
254         if (ptr == NULL)
255                 goto out_unlock;
256
257         ret = *ptr;
258         if (ret)
259                 atomic_inc(&ret->count);
260
261 out_unlock:
262         spin_unlock_irqrestore(&dev_state->lock, flags);
263
264         return ret;
265 }
266
267 static void free_pasid_state(struct pasid_state *pasid_state)
268 {
269         kfree(pasid_state);
270 }
271
272 static void put_pasid_state(struct pasid_state *pasid_state)
273 {
274         if (atomic_dec_and_test(&pasid_state->count))
275                 wake_up(&pasid_state->wq);
276 }
277
278 static void put_pasid_state_wait(struct pasid_state *pasid_state)
279 {
280         atomic_dec(&pasid_state->count);
281         wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
282         free_pasid_state(pasid_state);
283 }
284
285 static void unbind_pasid(struct pasid_state *pasid_state)
286 {
287         struct iommu_domain *domain;
288
289         domain = pasid_state->device_state->domain;
290
291         /*
292          * Mark pasid_state as invalid; no more faults will be added to the
293          * work queue after this is visible everywhere.
294          */
295         pasid_state->invalid = true;
296
297         /* Make sure this is visible */
298         smp_wmb();
299
300         /* After this the device/pasid can't access the mm anymore */
301         amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);
302
303         /* Make sure no more pending faults are in the queue */
304         flush_workqueue(iommu_wq);
305 }
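
/*
 * Illustrative note (not part of the original source): the teardown order
 * above matters.  pasid_state->invalid is set and made visible (smp_wmb())
 * before the GCR3 entry is cleared, so ppr_notifier() either sees the flag
 * and rejects new faults or has already queued work that is drained by
 * flush_workqueue().  Once unbind_pasid() returns, no fault work for this
 * PASID is still in flight.
 */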
306
307 static void free_pasid_states_level1(struct pasid_state **tbl)
308 {
309         int i;
310
311         for (i = 0; i < 512; ++i) {
312                 if (tbl[i] == NULL)
313                         continue;
314
315                 free_page((unsigned long)tbl[i]);
316         }
317 }
318
319 static void free_pasid_states_level2(struct pasid_state **tbl)
320 {
321         struct pasid_state **ptr;
322         int i;
323
324         for (i = 0; i < 512; ++i) {
325                 if (tbl[i] == NULL)
326                         continue;
327
328                 ptr = (struct pasid_state **)tbl[i];
329                 free_pasid_states_level1(ptr);
330         }
331 }
332
333 static void free_pasid_states(struct device_state *dev_state)
334 {
335         struct pasid_state *pasid_state;
336         int i;
337
338         for (i = 0; i < dev_state->max_pasids; ++i) {
339                 pasid_state = get_pasid_state(dev_state, i);
340                 if (pasid_state == NULL)
341                         continue;
342
343                 put_pasid_state(pasid_state);
344
345                 /*
346                  * This will call the mn_release function and
347                  * unbind the PASID
348                  */
349                 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
350
351                 put_pasid_state_wait(pasid_state); /* Reference taken in
352                                                       amd_iommu_bind_pasid */
353
354                 /* Drop reference taken in amd_iommu_bind_pasid */
355                 put_device_state(dev_state);
356         }
357
358         if (dev_state->pasid_levels == 2)
359                 free_pasid_states_level2(dev_state->states);
360         else if (dev_state->pasid_levels == 1)
361                 free_pasid_states_level1(dev_state->states);
362         else
363                 BUG_ON(dev_state->pasid_levels != 0);
364
365         free_page((unsigned long)dev_state->states);
366 }
367
368 static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
369 {
370         return container_of(mn, struct pasid_state, mn);
371 }
372
373 static void __mn_flush_page(struct mmu_notifier *mn,
374                             unsigned long address)
375 {
376         struct pasid_state *pasid_state;
377         struct device_state *dev_state;
378
379         pasid_state = mn_to_state(mn);
380         dev_state   = pasid_state->device_state;
381
382         amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, address);
383 }
384
385 static int mn_clear_flush_young(struct mmu_notifier *mn,
386                                 struct mm_struct *mm,
387                                 unsigned long start,
388                                 unsigned long end)
389 {
390         for (; start < end; start += PAGE_SIZE)
391                 __mn_flush_page(mn, start);
392
393         return 0;
394 }
395
396 static void mn_invalidate_range(struct mmu_notifier *mn,
397                                 struct mm_struct *mm,
398                                 unsigned long start, unsigned long end)
399 {
400         struct pasid_state *pasid_state;
401         struct device_state *dev_state;
402
403         pasid_state = mn_to_state(mn);
404         dev_state   = pasid_state->device_state;
405
406         if ((start ^ (end - 1)) < PAGE_SIZE)
407                 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
408                                      start);
409         else
410                 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
411 }
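
/*
 * Illustrative note (not part of the original source): the
 * (start ^ (end - 1)) test checks whether the invalidated range lies
 * within a single page.  With 4K pages, start = 0x1000 and end = 0x2000
 * give 0x1000 ^ 0x1fff = 0x0fff < PAGE_SIZE, so only that page is
 * flushed, while start = 0x1000 and end = 0x3000 give 0x3fff and trigger
 * a full per-PASID TLB flush instead.
 */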
412
413 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
414 {
415         struct pasid_state *pasid_state;
416         struct device_state *dev_state;
417         bool run_inv_ctx_cb;
418
419         might_sleep();
420
421         pasid_state    = mn_to_state(mn);
422         dev_state      = pasid_state->device_state;
423         run_inv_ctx_cb = !pasid_state->invalid;
424
425         if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
426                 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);
427
428         unbind_pasid(pasid_state);
429 }
430
431 static const struct mmu_notifier_ops iommu_mn = {
432         .release                = mn_release,
433         .clear_flush_young      = mn_clear_flush_young,
434         .invalidate_range       = mn_invalidate_range,
435 };
436
437 static void set_pri_tag_status(struct pasid_state *pasid_state,
438                                u16 tag, int status)
439 {
440         unsigned long flags;
441
442         spin_lock_irqsave(&pasid_state->lock, flags);
443         pasid_state->pri[tag].status = status;
444         spin_unlock_irqrestore(&pasid_state->lock, flags);
445 }
446
447 static void finish_pri_tag(struct device_state *dev_state,
448                            struct pasid_state *pasid_state,
449                            u16 tag)
450 {
451         unsigned long flags;
452
453         spin_lock_irqsave(&pasid_state->lock, flags);
454         if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
455             pasid_state->pri[tag].finish) {
456                 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
457                                        pasid_state->pri[tag].status, tag);
458                 pasid_state->pri[tag].finish = false;
459                 pasid_state->pri[tag].status = PPR_SUCCESS;
460         }
461         spin_unlock_irqrestore(&pasid_state->lock, flags);
462 }
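
/*
 * Illustrative note (not part of the original source): pri[tag].inflight
 * counts fault work items queued for a PRI tag and pri[tag].finish records
 * that a PPR completion must be sent for it.  Both ppr_notifier() and
 * do_fault() funnel through finish_pri_tag(), so the completion goes out
 * exactly once, when the last in-flight fault for the tag is done.
 */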
463
464 static void handle_fault_error(struct fault *fault)
465 {
466         int status;
467
468         if (!fault->dev_state->inv_ppr_cb) {
469                 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
470                 return;
471         }
472
473         status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
474                                               fault->pasid,
475                                               fault->address,
476                                               fault->flags);
477         switch (status) {
478         case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
479                 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
480                 break;
481         case AMD_IOMMU_INV_PRI_RSP_INVALID:
482                 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
483                 break;
484         case AMD_IOMMU_INV_PRI_RSP_FAIL:
485                 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
486                 break;
487         default:
488                 BUG();
489         }
490 }
491
492 static bool access_error(struct vm_area_struct *vma, struct fault *fault)
493 {
494         unsigned long requested = 0;
495
496         if (fault->flags & PPR_FAULT_EXEC)
497                 requested |= VM_EXEC;
498
499         if (fault->flags & PPR_FAULT_READ)
500                 requested |= VM_READ;
501
502         if (fault->flags & PPR_FAULT_WRITE)
503                 requested |= VM_WRITE;
504
505         return (requested & ~vma->vm_flags) != 0;
506 }
507
508 static void do_fault(struct work_struct *work)
509 {
510         struct fault *fault = container_of(work, struct fault, work);
511         struct vm_area_struct *vma;
512         vm_fault_t ret = VM_FAULT_ERROR;
513         unsigned int flags = 0;
514         struct mm_struct *mm;
515         u64 address;
516
517         mm = fault->state->mm;
518         address = fault->address;
519
520         if (fault->flags & PPR_FAULT_USER)
521                 flags |= FAULT_FLAG_USER;
522         if (fault->flags & PPR_FAULT_WRITE)
523                 flags |= FAULT_FLAG_WRITE;
524         flags |= FAULT_FLAG_REMOTE;
525
526         down_read(&mm->mmap_sem);
527         vma = find_extend_vma(mm, address);
528         if (!vma || address < vma->vm_start)
529                 /* failed to get a vma in the right range */
530                 goto out;
531
532         /* Check if we have the right permissions on the vma */
533         if (access_error(vma, fault))
534                 goto out;
535
536         ret = handle_mm_fault(vma, address, flags);
537 out:
538         up_read(&mm->mmap_sem);
539
540         if (ret & VM_FAULT_ERROR)
541                 /* failed to service fault */
542                 handle_fault_error(fault);
543
544         finish_pri_tag(fault->dev_state, fault->state, fault->tag);
545
546         put_pasid_state(fault->state);
547
548         kfree(fault);
549 }
550
551 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
552 {
553         struct amd_iommu_fault *iommu_fault;
554         struct pasid_state *pasid_state;
555         struct device_state *dev_state;
556         unsigned long flags;
557         struct fault *fault;
558         bool finish;
559         u16 tag, devid;
560         int ret;
561         struct iommu_dev_data *dev_data;
562         struct pci_dev *pdev = NULL;
563
564         iommu_fault = data;
565         tag         = iommu_fault->tag & 0x1ff;
566         finish      = (iommu_fault->tag >> 9) & 1;
567
568         devid = iommu_fault->device_id;
569         pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
570                                            devid & 0xff);
571         if (!pdev)
572                 return -ENODEV;
573         dev_data = get_dev_data(&pdev->dev);
574
575         /* In a kdump kernel the PCI dev is not initialized yet -> send INVALID */
576         ret = NOTIFY_DONE;
577         if (translation_pre_enabled(amd_iommu_rlookup_table[devid])
578                 && dev_data->defer_attach) {
579                 amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
580                                        PPR_INVALID, tag);
581                 goto out;
582         }
583
584         dev_state = get_device_state(iommu_fault->device_id);
585         if (dev_state == NULL)
586                 goto out;
587
588         pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
589         if (pasid_state == NULL || pasid_state->invalid) {
590                 /* We know the device but not the PASID -> send INVALID */
591                 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
592                                        PPR_INVALID, tag);
593                 goto out_drop_state;
594         }
595
596         spin_lock_irqsave(&pasid_state->lock, flags);
597         atomic_inc(&pasid_state->pri[tag].inflight);
598         if (finish)
599                 pasid_state->pri[tag].finish = true;
600         spin_unlock_irqrestore(&pasid_state->lock, flags);
601
602         fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
603         if (fault == NULL) {
604                 /* We are OOM - send success and let the device re-fault */
605                 finish_pri_tag(dev_state, pasid_state, tag);
606                 goto out_drop_state;
607         }
608
609         fault->dev_state = dev_state;
610         fault->address   = iommu_fault->address;
611         fault->state     = pasid_state;
612         fault->tag       = tag;
613         fault->finish    = finish;
614         fault->pasid     = iommu_fault->pasid;
615         fault->flags     = iommu_fault->flags;
616         INIT_WORK(&fault->work, do_fault);
617
618         queue_work(iommu_wq, &fault->work);
619
620         ret = NOTIFY_OK;
621
622 out_drop_state:
623
624         if (ret != NOTIFY_OK && pasid_state)
625                 put_pasid_state(pasid_state);
626
627         put_device_state(dev_state);
628
629 out:
630         return ret;
631 }
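
/*
 * Illustrative note (not part of the original source): the low 9 bits of
 * iommu_fault->tag index the pri[] array (PRI_QUEUE_SIZE == 512) and bit 9
 * is the "finish" flag; a raw tag of 0x215, for instance, decodes to
 * tag 0x15 with finish set.  The notifier itself only bumps the inflight
 * count and queues do_fault(); the page fault is resolved later in
 * process context on iommu_wq.
 */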
632
633 static struct notifier_block ppr_nb = {
634         .notifier_call = ppr_notifier,
635 };
636
637 int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
638                          struct task_struct *task)
639 {
640         struct pasid_state *pasid_state;
641         struct device_state *dev_state;
642         struct mm_struct *mm;
643         u16 devid;
644         int ret;
645
646         might_sleep();
647
648         if (!amd_iommu_v2_supported())
649                 return -ENODEV;
650
651         devid     = device_id(pdev);
652         dev_state = get_device_state(devid);
653
654         if (dev_state == NULL)
655                 return -EINVAL;
656
657         ret = -EINVAL;
658         if (pasid < 0 || pasid >= dev_state->max_pasids)
659                 goto out;
660
661         ret = -ENOMEM;
662         pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
663         if (pasid_state == NULL)
664                 goto out;
665
666
667         atomic_set(&pasid_state->count, 1);
668         init_waitqueue_head(&pasid_state->wq);
669         spin_lock_init(&pasid_state->lock);
670
671         mm                        = get_task_mm(task);
672         pasid_state->mm           = mm;
673         pasid_state->device_state = dev_state;
674         pasid_state->pasid        = pasid;
675         pasid_state->invalid      = true; /* Mark as valid only if we are
676                                              done with setting up the pasid */
677         pasid_state->mn.ops       = &iommu_mn;
678
679         if (pasid_state->mm == NULL)
680                 goto out_free;
681
682         mmu_notifier_register(&pasid_state->mn, mm);
683
684         ret = set_pasid_state(dev_state, pasid_state, pasid);
685         if (ret)
686                 goto out_unregister;
687
688         ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
689                                         __pa(pasid_state->mm->pgd));
690         if (ret)
691                 goto out_clear_state;
692
693         /* Now we are ready to handle faults */
694         pasid_state->invalid = false;
695
696         /*
697          * Drop the reference to the mm_struct here. We rely on the
698          * mmu_notifier release call-back to inform us when the mm
699          * is going away.
700          */
701         mmput(mm);
702
703         return 0;
704
705 out_clear_state:
706         clear_pasid_state(dev_state, pasid);
707
708 out_unregister:
709         mmu_notifier_unregister(&pasid_state->mn, mm);
710         mmput(mm);
711
712 out_free:
713         free_pasid_state(pasid_state);
714
715 out:
716         put_device_state(dev_state);
717
718         return ret;
719 }
720 EXPORT_SYMBOL(amd_iommu_bind_pasid);
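
/*
 * Illustrative note (not part of the original source):
 * amd_iommu_bind_pasid() keeps the device_state reference taken by
 * get_device_state() on success; it is only dropped again in
 * amd_iommu_unbind_pasid().  The mm reference from get_task_mm() is
 * dropped right away because the mmu_notifier release callback
 * (mn_release) tears the binding down when the address space goes away.
 */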
721
722 void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
723 {
724         struct pasid_state *pasid_state;
725         struct device_state *dev_state;
726         u16 devid;
727
728         might_sleep();
729
730         if (!amd_iommu_v2_supported())
731                 return;
732
733         devid = device_id(pdev);
734         dev_state = get_device_state(devid);
735         if (dev_state == NULL)
736                 return;
737
738         if (pasid < 0 || pasid >= dev_state->max_pasids)
739                 goto out;
740
741         pasid_state = get_pasid_state(dev_state, pasid);
742         if (pasid_state == NULL)
743                 goto out;
744         /*
745          * Drop the reference taken just above. We are safe because we still hold
746          * the reference taken in the amd_iommu_bind_pasid function.
747          */
748         put_pasid_state(pasid_state);
749
750         /* Clear the pasid state so that the pasid can be re-used */
751         clear_pasid_state(dev_state, pasid_state->pasid);
752
753         /*
754          * Call mmu_notifier_unregister to drop our reference
755          * to pasid_state->mm
756          */
757         mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
758
759         put_pasid_state_wait(pasid_state); /* Reference taken in
760                                               amd_iommu_bind_pasid */
761 out:
762         /* Drop reference taken in this function */
763         put_device_state(dev_state);
764
765         /* Drop reference taken in amd_iommu_bind_pasid */
766         put_device_state(dev_state);
767 }
768 EXPORT_SYMBOL(amd_iommu_unbind_pasid);
769
770 int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
771 {
772         struct device_state *dev_state;
773         struct iommu_group *group;
774         unsigned long flags;
775         int ret, tmp;
776         u16 devid;
777
778         might_sleep();
779
780         if (!amd_iommu_v2_supported())
781                 return -ENODEV;
782
783         if (pasids <= 0 || pasids > (PASID_MASK + 1))
784                 return -EINVAL;
785
786         devid = device_id(pdev);
787
788         dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
789         if (dev_state == NULL)
790                 return -ENOMEM;
791
792         spin_lock_init(&dev_state->lock);
793         init_waitqueue_head(&dev_state->wq);
794         dev_state->pdev  = pdev;
795         dev_state->devid = devid;
796
797         tmp = pasids;
798         for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
799                 dev_state->pasid_levels += 1;
800
801         atomic_set(&dev_state->count, 1);
802         dev_state->max_pasids = pasids;
803
804         ret = -ENOMEM;
805         dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
806         if (dev_state->states == NULL)
807                 goto out_free_dev_state;
808
809         dev_state->domain = iommu_domain_alloc(&pci_bus_type);
810         if (dev_state->domain == NULL)
811                 goto out_free_states;
812
813         amd_iommu_domain_direct_map(dev_state->domain);
814
815         ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
816         if (ret)
817                 goto out_free_domain;
818
819         group = iommu_group_get(&pdev->dev);
820         if (!group) {
821                 ret = -EINVAL;
822                 goto out_free_domain;
823         }
824
825         ret = iommu_attach_group(dev_state->domain, group);
826         if (ret != 0)
827                 goto out_drop_group;
828
829         iommu_group_put(group);
830
831         spin_lock_irqsave(&state_lock, flags);
832
833         if (__get_device_state(devid) != NULL) {
834                 spin_unlock_irqrestore(&state_lock, flags);
835                 ret = -EBUSY;
836                 goto out_free_domain;
837         }
838
839         list_add_tail(&dev_state->list, &state_list);
840
841         spin_unlock_irqrestore(&state_lock, flags);
842
843         return 0;
844
845 out_drop_group:
846         iommu_group_put(group);
847
848 out_free_domain:
849         iommu_domain_free(dev_state->domain);
850
851 out_free_states:
852         free_page((unsigned long)dev_state->states);
853
854 out_free_dev_state:
855         kfree(dev_state);
856
857         return ret;
858 }
859 EXPORT_SYMBOL(amd_iommu_init_device);
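
/*
 * Illustrative note (not part of the original source): pasid_levels counts
 * the non-leaf levels of the PASID state table, with 512 entries per page.
 * Asking for up to 512 PASIDs gives pasid_levels == 0 (a single leaf
 * page), while 65536 PASIDs gives pasid_levels == 1 (one directory page
 * pointing to up to 512 leaf pages).
 */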
860
861 void amd_iommu_free_device(struct pci_dev *pdev)
862 {
863         struct device_state *dev_state;
864         unsigned long flags;
865         u16 devid;
866
867         if (!amd_iommu_v2_supported())
868                 return;
869
870         devid = device_id(pdev);
871
872         spin_lock_irqsave(&state_lock, flags);
873
874         dev_state = __get_device_state(devid);
875         if (dev_state == NULL) {
876                 spin_unlock_irqrestore(&state_lock, flags);
877                 return;
878         }
879
880         list_del(&dev_state->list);
881
882         spin_unlock_irqrestore(&state_lock, flags);
883
884         /* Get rid of any remaining pasid states */
885         free_pasid_states(dev_state);
886
887         put_device_state(dev_state);
888         /*
889          * Wait until the last reference is dropped before freeing
890          * the device state.
891          */
892         wait_event(dev_state->wq, !atomic_read(&dev_state->count));
893         free_device_state(dev_state);
894 }
895 EXPORT_SYMBOL(amd_iommu_free_device);
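
/*
 * Illustrative usage sketch (not part of the original source): a driver
 * for a PRI/PASID-capable device would typically combine the exported
 * calls roughly as below; "my_pdev" and the PASID value 1 are placeholders.
 *
 *	ret = amd_iommu_init_device(my_pdev, 16);
 *	if (!ret)
 *		ret = amd_iommu_bind_pasid(my_pdev, 1, current);
 *	...
 *	amd_iommu_unbind_pasid(my_pdev, 1);
 *	amd_iommu_free_device(my_pdev);
 */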
896
897 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
898                                  amd_iommu_invalid_ppr_cb cb)
899 {
900         struct device_state *dev_state;
901         unsigned long flags;
902         u16 devid;
903         int ret;
904
905         if (!amd_iommu_v2_supported())
906                 return -ENODEV;
907
908         devid = device_id(pdev);
909
910         spin_lock_irqsave(&state_lock, flags);
911
912         ret = -EINVAL;
913         dev_state = __get_device_state(devid);
914         if (dev_state == NULL)
915                 goto out_unlock;
916
917         dev_state->inv_ppr_cb = cb;
918
919         ret = 0;
920
921 out_unlock:
922         spin_unlock_irqrestore(&state_lock, flags);
923
924         return ret;
925 }
926 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);
927
928 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
929                                     amd_iommu_invalidate_ctx cb)
930 {
931         struct device_state *dev_state;
932         unsigned long flags;
933         u16 devid;
934         int ret;
935
936         if (!amd_iommu_v2_supported())
937                 return -ENODEV;
938
939         devid = device_id(pdev);
940
941         spin_lock_irqsave(&state_lock, flags);
942
943         ret = -EINVAL;
944         dev_state = __get_device_state(devid);
945         if (dev_state == NULL)
946                 goto out_unlock;
947
948         dev_state->inv_ctx_cb = cb;
949
950         ret = 0;
951
952 out_unlock:
953         spin_unlock_irqrestore(&state_lock, flags);
954
955         return ret;
956 }
957 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);
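
/*
 * Illustrative note (not part of the original source): the two optional
 * callbacks hook into different paths above.  inv_ppr_cb is consulted in
 * handle_fault_error() when handle_mm_fault() cannot resolve a PPR fault
 * and decides which PPR status is reported back; inv_ctx_cb is called
 * from mn_release() when the bound address space goes away, unless the
 * PASID was already being torn down.
 */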
958
959 static int __init amd_iommu_v2_init(void)
960 {
961         int ret;
962
963         pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");
964
965         if (!amd_iommu_v2_supported()) {
966                 pr_info("AMD IOMMUv2 functionality not available on this system\n");
967                 /*
968                  * Load anyway to provide the symbols to other modules
969                  * which may use AMD IOMMUv2 optionally.
970                  */
971                 return 0;
972         }
973
974         spin_lock_init(&state_lock);
975
976         ret = -ENOMEM;
977         iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
978         if (iommu_wq == NULL)
979                 goto out;
980
981         amd_iommu_register_ppr_notifier(&ppr_nb);
982
983         return 0;
984
985 out:
986         return ret;
987 }
988
989 static void __exit amd_iommu_v2_exit(void)
990 {
991         struct device_state *dev_state;
992         int i;
993
994         if (!amd_iommu_v2_supported())
995                 return;
996
997         amd_iommu_unregister_ppr_notifier(&ppr_nb);
998
999         flush_workqueue(iommu_wq);
1000
1001         /*
1002          * The loop below might call flush_workqueue(), so call
1003          * destroy_workqueue() after it
1004          */
1005         for (i = 0; i < MAX_DEVICES; ++i) {
1006                 dev_state = get_device_state(i);
1007
1008                 if (dev_state == NULL)
1009                         continue;
1010
1011                 WARN_ON_ONCE(1);
1012
1013                 put_device_state(dev_state);
1014                 amd_iommu_free_device(dev_state->pdev);
1015         }
1016
1017         destroy_workqueue(iommu_wq);
1018 }
1019
1020 module_init(amd_iommu_v2_init);
1021 module_exit(amd_iommu_v2_exit);