/* source: ~madcoder/pwqr.git kernel/pwqr.c @ 4d4c4cd10b5a8112c4e57c367287dcf89e2b1aa9 */
/*
 * Copyright (C) 2012   Pierre Habouzit <pierre.habouzit@intersec.com>
 * Copyright (C) 2012   Intersec SAS
 *
 * This file implements the Linux Pthread Workqueue Regulator, and is part
 * of the linux kernel.
 *
 * The Linux Kernel is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published by
 * the Free Software Foundation.
 *
 * The Linux Kernel is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
 * License for more details.
 *
 * You should have received a copy of the GNU General Public License version 2
 * along with The Linux Kernel.  If not, see <http://www.gnu.org/licenses/>.
 */
20
#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/hash.h>
#include <linux/init.h>
#include <linux/kref.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/uaccess.h>
#include <linux/wait.h>

#ifndef CONFIG_PREEMPT_NOTIFIERS
#  error PWQ module requires CONFIG_PREEMPT_NOTIFIERS
#endif

#include "pwqr.h"

#define PWQR_HASH_BITS          5
#define PWQR_HASH_SIZE          (1 << PWQR_HASH_BITS)
43
/* One bucket of the global task hash: a lock plus the chained entries. */
struct pwqr_task_bucket {
        spinlock_t              lock;   /* protects @tasks */
        struct hlist_head       tasks;  /* pwqr_task.link entries */
};
48
/*
 * Scoreboard: per-open-fd accounting of one process's worker pool.
 * All counters below are protected by wqh.lock (see pwqr_sb_lock_irqsave()).
 */
struct pwqr_sb {
        struct kref             kref;   /* refcount; release defers free via RCU */
        struct rcu_head         rcu;    /* deferred free (pwqr_sb_finalize) */
        wait_queue_head_t       wqh;    /* sleepers; wqh.lock guards counters */
        pid_t                   tgid;   /* creating process; enforced in ioctl */

        unsigned                concurrency;     /* target #running (default #cpus) */
        unsigned                registered;      /* threads attached to this sb */

        unsigned                running;         /* threads accounted as on-CPU */
        unsigned                waiting;         /* threads blocked in PWQR_WAIT */
        unsigned                quarantined;     /* waiters benched by overcommit */
        unsigned                parked;          /* threads blocked in PWQR_PARK */
        unsigned                overcommit_wakes; /* outstanding PWQR_WAKE_OC credits */

        unsigned                dead;            /* set on release(); fd unusable */
};
66
/* Per-registered-thread state, hashed by task_struct pointer. */
struct pwqr_task {
        struct preempt_notifier notifier;  /* sched hooks; .ops encodes run state */
        struct hlist_node       link;      /* chain in pwqr_tasks_hash bucket */
        struct rcu_head         rcu;       /* kfree_rcu from notifier context */
        struct task_struct     *task;      /* the thread this entry tracks */
        struct pwqr_sb         *sb;        /* scoreboard the thread is attached to */
};
74
/*
 * Global variables
 */
static struct class            *pwqr_class;    /* device class for the /dev node */
static int                      pwqr_major;    /* chardev major from register_chrdev() */
static struct pwqr_task_bucket  pwqr_tasks_hash[PWQR_HASH_SIZE]; /* task -> pwqr_task */
static struct preempt_ops       pwqr_preempt_running_ops; /* task counted as running */
static struct preempt_ops       pwqr_preempt_blocked_ops; /* task counted as blocked */
static struct preempt_ops       pwqr_preempt_noop_ops;    /* detached: ignore events */
84
/*****************************************************************************
 * Scoreboards
 */

/* Scoreboard counters are guarded by the waitqueue's own spinlock. */
#define pwqr_sb_lock_irqsave(sb, flags) \
        spin_lock_irqsave(&(sb)->wqh.lock, flags)
#define pwqr_sb_unlock_irqrestore(sb, flags) \
        spin_unlock_irqrestore(&(sb)->wqh.lock, flags)
93
/*
 * Re-balance the scoreboard after applying @running_delta to sb->running.
 *
 * Caller must hold sb->wqh.lock (this uses wake_up_locked()).
 *
 * overcommit > 0: too many runnable threads — move up to @overcommit
 * threads from "waiting" to "quarantined" so they stay benched.
 * overcommit < 0: spare capacity — move quarantined threads back to
 * "waiting"; when there is nothing to unquarantine and nothing waiting,
 * wake one parked thread instead.
 */
static inline void __pwqr_sb_update_state(struct pwqr_sb *sb, int running_delta)
{
        int overcommit;

        sb->running += running_delta;
        overcommit = sb->running + sb->waiting - sb->concurrency;
        if (overcommit == 0)
                return;

        if (overcommit > 0) {
                if (overcommit > sb->waiting) {
                        /* quarantine every waiter */
                        sb->quarantined += sb->waiting;
                        sb->waiting      = 0;
                } else {
                        sb->quarantined += overcommit;
                        sb->waiting     -= overcommit;
                }
        } else {
                unsigned undercommit = -overcommit;

                if (undercommit < sb->quarantined) {
                        sb->waiting     += undercommit;
                        sb->quarantined -= undercommit;
                } else if (sb->quarantined) {
                        sb->waiting     += sb->quarantined;
                        sb->quarantined  = 0;
                } else if (sb->waiting == 0 && sb->parked) {
                        /* no waiter at all: pull one parked thread in */
                        wake_up_locked(&sb->wqh);
                }
        }
}
125
126 static struct pwqr_sb *pwqr_sb_create(void)
127 {
128         struct pwqr_sb *sb;
129
130         sb = kzalloc(sizeof(struct pwqr_sb), GFP_KERNEL);
131         if (sb == NULL)
132                 return ERR_PTR(-ENOMEM);
133
134         kref_init(&sb->kref);
135         init_waitqueue_head(&sb->wqh);
136         sb->tgid        = current->tgid;
137         sb->concurrency = num_online_cpus();
138
139         __module_get(THIS_MODULE);
140         return sb;
141 }
/* Take an extra reference on @sb. */
static inline void pwqr_sb_get(struct pwqr_sb *sb)
{
        kref_get(&sb->kref);
}
146
/* RCU callback: last reference is gone, actually free the scoreboard. */
static void pwqr_sb_finalize(struct rcu_head *rcu)
{
        struct pwqr_sb *sb = container_of(rcu, struct pwqr_sb, rcu);

        /* drop the module pin taken in pwqr_sb_create() */
        module_put(THIS_MODULE);
        kfree(sb);
}
154
/* kref release: defer the actual free past an RCU grace period. */
static void pwqr_sb_release(struct kref *kref)
{
        struct pwqr_sb *sb = container_of(kref, struct pwqr_sb, kref);

        call_rcu(&sb->rcu, pwqr_sb_finalize);
}
/* Drop a reference on @sb; frees it (via RCU) when it was the last one. */
static inline void pwqr_sb_put(struct pwqr_sb *sb)
{
        kref_put(&sb->kref, pwqr_sb_release);
}
165
/*****************************************************************************
 * tasks
 */
/* Map a task_struct pointer to its bucket in the global task hash. */
static inline struct pwqr_task_bucket *task_hbucket(struct task_struct *task)
{
        return &pwqr_tasks_hash[hash_ptr(task, PWQR_HASH_BITS)];
}
173
174 static struct pwqr_task *pwqr_task_find(struct task_struct *task)
175 {
176         struct pwqr_task_bucket *b = task_hbucket(task);
177         struct hlist_node *node;
178         struct pwqr_task *pwqt = NULL;
179
180         spin_lock(&b->lock);
181         hlist_for_each_entry(pwqt, node, &b->tasks, link) {
182                 if (pwqt->task == task)
183                         break;
184         }
185         spin_unlock(&b->lock);
186         return pwqt;
187 }
188
189 static struct pwqr_task *pwqr_task_create(struct task_struct *task)
190 {
191         struct pwqr_task_bucket *b = task_hbucket(task);
192         struct pwqr_task *pwqt;
193
194         pwqt = kmalloc(sizeof(*pwqt), GFP_KERNEL);
195         if (pwqt == NULL)
196                 return ERR_PTR(-ENOMEM);
197
198         preempt_notifier_init(&pwqt->notifier, &pwqr_preempt_running_ops);
199         preempt_notifier_register(&pwqt->notifier);
200         pwqt->task = task;
201
202         spin_lock(&b->lock);
203         hlist_add_head(&pwqt->link, &b->tasks);
204         spin_unlock(&b->lock);
205
206         return pwqt;
207 }
208
/*
 * Unbind @pwqt from scoreboard @sb: drop the registered count, fix the
 * running count if the task was accounted as running, then drop the
 * scoreboard reference.  pwqt->sb is NULL on return.
 */
__cold
static void pwqr_task_detach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->registered--;
        if (pwqt->notifier.ops == &pwqr_preempt_running_ops) {
                /* task was counted in sb->running: remove it */
                __pwqr_sb_update_state(sb, -1);
        } else {
                /* blocked/noop: counts unchanged, but re-balance anyway */
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
        pwqr_sb_put(sb);
        pwqt->sb = NULL;
}
225
226 __cold
227 static void pwqr_task_attach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
228 {
229         unsigned long flags;
230
231         pwqr_sb_lock_irqsave(sb, flags);
232         pwqr_sb_get(pwqt->sb = sb);
233         sb->registered++;
234         __pwqr_sb_update_state(sb, 1);
235         pwqr_sb_unlock_irqrestore(sb, flags);
236 }
237
/*
 * Remove @pwqt from the task hash and free it.
 *
 * @from_notifier: true when called from a sched_in/out hook, where
 * preempt_notifier_unregister() and plain kfree() are forbidden; the task
 * must then be dying (TASK_DEAD) and the memory is reclaimed via RCU.
 */
__cold
static void pwqr_task_release(struct pwqr_task *pwqt, bool from_notifier)
{
        struct pwqr_task_bucket *b = task_hbucket(pwqt->task);

        spin_lock(&b->lock);
        hlist_del(&pwqt->link);
        spin_unlock(&b->lock);
        /* neutralize any further scheduler callbacks */
        pwqt->notifier.ops = &pwqr_preempt_noop_ops;

        if (from_notifier) {
                /* When called from sched_{out,in}, it's not allowed to
                 * call preempt_notifier_unregister (or worse kfree())
                 *
                 * Though it's not a good idea to kfree() still registered
                 * callbacks if we're not dying, it'll panic on the next
                 * sched_{in,out} call.
                 */
                BUG_ON(!(pwqt->task->state & TASK_DEAD));
                kfree_rcu(pwqt, rcu);
        } else {
                preempt_notifier_unregister(&pwqt->notifier);
                kfree(pwqt);
        }
}
263
/* Do-nothing notifier hooks, installed once a task is unhooked. */
static void pwqr_task_noop_sched_in(struct preempt_notifier *notifier, int cpu)
{
}

static void pwqr_task_noop_sched_out(struct preempt_notifier *notifier,
                                    struct task_struct *next)
{
}
272
/*
 * sched_in hook while the task is accounted as blocked: the thread got the
 * CPU back, so flip it back to "running" accounting.
 */
static void pwqr_task_blocked_sched_in(struct preempt_notifier *notifier, int cpu)
{
        struct pwqr_task *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb   *sb   = pwqt->sb;
        unsigned long flags;

        if (unlikely(sb->dead)) {
                /* scoreboard fd closed: tear down from notifier context */
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }

        pwqt->notifier.ops = &pwqr_preempt_running_ops;
        pwqr_sb_lock_irqsave(sb, flags);
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}
290
/*
 * sched_out hook: if the task is actually blocking (not merely preempted,
 * stopped or traced), account it as blocked so a replacement can be woken.
 */
static void pwqr_task_sched_out(struct preempt_notifier *notifier,
                               struct task_struct *next)
{
        struct pwqr_task    *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb      *sb   = pwqt->sb;
        struct task_struct *p    = pwqt->task;

        if (unlikely(p->state & TASK_DEAD) || unlikely(sb->dead)) {
                /* dying task or closed scoreboard: tear down from here */
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }
        /* state == 0 is TASK_RUNNING: preemption, not a block */
        if (p->state == 0 || (p->state & (__TASK_STOPPED | __TASK_TRACED)))
                return;

        pwqt->notifier.ops = &pwqr_preempt_blocked_ops;
        /* see preempt.h: irq are disabled for sched_out */
        spin_lock(&sb->wqh.lock);
        __pwqr_sb_update_state(sb, -1);
        spin_unlock(&sb->wqh.lock);
}
312
/* Detached / release-pending: ignore scheduler events. */
static struct preempt_ops __read_mostly pwqr_preempt_noop_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_noop_sched_out,
};

/* Task accounted as running: only sched_out is interesting. */
static struct preempt_ops __read_mostly pwqr_preempt_running_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_sched_out,
};

/* Task accounted as blocked: sched_in flips it back to running. */
static struct preempt_ops __read_mostly pwqr_preempt_blocked_ops = {
        .sched_in       = pwqr_task_blocked_sched_in,
        .sched_out      = pwqr_task_sched_out,
};
327
328 /*****************************************************************************
329  * file descriptor
330  */
331 static int pwqr_open(struct inode *inode, struct file *filp)
332 {
333         struct pwqr_sb *sb;
334
335         sb = pwqr_sb_create();
336         if (IS_ERR(sb))
337                 return PTR_ERR(sb);
338         filp->private_data = sb;
339         return 0;
340 }
341
342 static int pwqr_release(struct inode *inode, struct file *filp)
343 {
344         struct pwqr_sb *sb = filp->private_data;
345         unsigned long flags;
346
347         pwqr_sb_lock_irqsave(sb, flags);
348         sb->dead = true;
349         pwqr_sb_unlock_irqrestore(sb, flags);
350         wake_up_all(&sb->wqh);
351         pwqr_sb_put(sb);
352         return 0;
353 }
354
/*
 * Sleep until there is work (PWQR_WAIT) or until unparked (PWQR_PARK).
 *
 * @sb:      the scoreboard
 * @pwqt:    calling thread's entry; its notifier is parked while we sleep
 * @in_pool: true for PWQR_WAIT (worker with a ticket), false for PWQR_PARK
 * @arg:     userland pwqr_ioc_wait (PWQR_WAIT only)
 *
 * Returns 0 on wake-up, -EWOULDBLOCK when the ticket is stale, -EDQUOT
 * when woken while the pool is overcommitted, -ERESTARTSYS on signal,
 * -EBADFD once the scoreboard is dead, -EFAULT on bad userland pointers.
 */
static long
do_pwqr_wait(struct pwqr_sb *sb, struct pwqr_task *pwqt,
            int in_pool, struct pwqr_ioc_wait __user *arg)
{
        unsigned long flags;
        struct pwqr_ioc_wait wait;
        long rc = 0;
        u32 uval;

        /* suspend sched accounting: the sleep below is intentional */
        preempt_notifier_unregister(&pwqt->notifier);

        if (in_pool && copy_from_user(&wait, arg, sizeof(wait))) {
                rc = -EFAULT;
                goto out;
        }

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->running + sb->waiting <= sb->concurrency) {
                if (in_pool) {
                        /* Futex-style check: only sleep if *pwqr_uaddr still
                         * equals the ticket; fault the page in with the lock
                         * dropped when it is not resident.
                         *
                         * NOTE(review): probe_kernel_address() is meant for
                         * kernel addresses, and the get_user() cast lacks
                         * __user — both look dubious for a userland pointer;
                         * confirm against current kernel uaccess APIs.
                         */
                        while (probe_kernel_address(wait.pwqr_uaddr, uval)) {
                                pwqr_sb_unlock_irqrestore(sb, flags);
                                rc = get_user(uval, (u32 *)wait.pwqr_uaddr);
                                if (rc)
                                        goto out;
                                pwqr_sb_lock_irqsave(sb, flags);
                        }

                        if (uval != (u32)wait.pwqr_ticket) {
                                rc = -EWOULDBLOCK;
                                goto out_unlock;
                        }
                } else {
                        /* PWQR_PARK with spare capacity: return at once */
                        BUG_ON(sb->quarantined != 0);
                        goto out_unlock;
                }
        }

        /* @ see <wait_event_interruptible_exclusive_locked_irq> */
        if (likely(!sb->dead)) {
                DEFINE_WAIT(__wait);

                __wait.flags |= WQ_FLAG_EXCLUSIVE;

                if (in_pool) {
                        /* waiters queue at the head of the waitqueue */
                        sb->waiting++;
                        __add_wait_queue(&sb->wqh, &__wait);
                } else {
                        /* parked threads queue at the tail */
                        sb->parked++;
                        __add_wait_queue_tail(&sb->wqh, &__wait);
                }
                __pwqr_sb_update_state(sb, -1);
                set_current_state(TASK_INTERRUPTIBLE);

                do {
                        if (sb->overcommit_wakes)
                                break;
                        if (signal_pending(current)) {
                                rc = -ERESTARTSYS;
                                break;
                        }
                        spin_unlock_irq(&sb->wqh.lock);
                        schedule();
                        spin_lock_irq(&sb->wqh.lock);
                        if (in_pool && sb->waiting)
                                break;
                        if (sb->running + sb->waiting < sb->concurrency)
                                break;
                } while (likely(!sb->dead));

                __remove_wait_queue(&sb->wqh, &__wait);
                __set_current_state(TASK_RUNNING);

                if (in_pool) {
                        /* we were counted either as waiting or quarantined */
                        if (sb->waiting) {
                                sb->waiting--;
                        } else {
                                BUG_ON(!sb->quarantined);
                                sb->quarantined--;
                        }
                } else {
                        sb->parked--;
                }
                __pwqr_sb_update_state(sb, 1);
                if (sb->overcommit_wakes)
                        sb->overcommit_wakes--;  /* consume one OC credit */
                if (sb->waiting + sb->running > sb->concurrency)
                        rc = -EDQUOT;            /* woken into overcommit */
        }

out_unlock:
        if (unlikely(sb->dead))
                rc = -EBADFD;
        pwqr_sb_unlock_irqrestore(sb, flags);
out:
        /* resume scheduler accounting for this thread */
        preempt_notifier_register(&pwqt->notifier);
        return rc;
}
452
453 static long do_pwqr_unregister(struct pwqr_sb *sb, struct pwqr_task *pwqt)
454 {
455         if (!pwqt)
456                 return -EINVAL;
457         if (pwqt->sb != sb)
458                 return -ENOENT;
459         pwqr_task_detach(pwqt, sb);
460         pwqr_task_release(pwqt, false);
461         return 0;
462 }
463
464 static long do_pwqr_set_conc(struct pwqr_sb *sb, int conc)
465 {
466         long old_conc = sb->concurrency;
467         unsigned long flags;
468
469         pwqr_sb_lock_irqsave(sb, flags);
470         if (conc <= 0)
471                 conc = num_online_cpus();
472         if (conc != old_conc) {
473                 sb->concurrency = conc;
474                 __pwqr_sb_update_state(sb, 0);
475         }
476         pwqr_sb_unlock_irqrestore(sb, flags);
477
478         return old_conc;
479 }
480
/*
 * PWQR_WAKE / PWQR_WAKE_OC: wake up to @count sleeping threads.
 *
 * @oc != 0 (PWQR_WAKE_OC): overcommit wakes — wake threads even though the
 * concurrency target is met, granting each an overcommit_wakes credit
 * (consumed in do_pwqr_wait()).
 *
 * Returns the number of threads userland should consider woken (possibly
 * less than @count, and deliberately 0 on the quarantine path), or
 * -EINVAL for a negative count.
 */
static long do_pwqr_wake(struct pwqr_sb *sb, int oc, int count)
{
        unsigned long flags;
        int nwake;

        if (count < 0)
                return -EINVAL;

        pwqr_sb_lock_irqsave(sb, flags);

        if (oc) {
                /* cap at sleepers not already promised an OC wake */
                nwake = sb->waiting + sb->quarantined + sb->parked - sb->overcommit_wakes;
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
                sb->overcommit_wakes += count;
        } else if (sb->running + sb->overcommit_wakes < sb->concurrency) {
                /* normal wake: only up to the spare concurrency */
                nwake = sb->concurrency - sb->overcommit_wakes - sb->running;
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
        } else {
                /*
                 * This codepath deserves an explanation: when a thread is
                 * quarantined, for us, really, it's already "parked", though
                 * userland doesn't know about it.  So wake as many threads
                 * as userland would have liked to, and let the wakeup tell
                 * userland those should be parked.
                 *
                 * That's why we lie about the number of woken threads:
                 * userland-wise we woke up a thread only so that it could be
                 * parked for real and avoid spurious syscalls, so it's as
                 * if we woke up 0 threads.
                 */
                nwake = sb->quarantined;
                if (sb->waiting < sb->overcommit_wakes)
                        nwake -= sb->overcommit_wakes - sb->waiting;
                if (nwake > count)
                        nwake = count;
                count = 0;
        }
        while (nwake-- > 0)
                wake_up_locked(&sb->wqh);
        pwqr_sb_unlock_irqrestore(sb, flags);

        return count;
}
532
533 static long pwqr_ioctl(struct file *filp, unsigned command, unsigned long arg)
534 {
535         struct pwqr_sb      *sb   = filp->private_data;
536         struct task_struct *task = current;
537         struct pwqr_task    *pwqt;
538         int rc = 0;
539
540         if (sb->tgid != current->tgid)
541                 return -EBADFD;
542
543         switch (command) {
544         case PWQR_GET_CONC:
545                 return sb->concurrency;
546         case PWQR_SET_CONC:
547                 return do_pwqr_set_conc(sb, (int)arg);
548
549         case PWQR_WAKE:
550         case PWQR_WAKE_OC:
551                 return do_pwqr_wake(sb, command == PWQR_WAKE_OC, (int)arg);
552
553         case PWQR_WAIT:
554         case PWQR_PARK:
555         case PWQR_REGISTER:
556         case PWQR_UNREGISTER:
557                 break;
558         default:
559                 return -EINVAL;
560         }
561
562         pwqt = pwqr_task_find(task);
563         if (command == PWQR_UNREGISTER)
564                 return do_pwqr_unregister(sb, pwqt);
565
566         if (pwqt == NULL) {
567                 pwqt = pwqr_task_create(task);
568                 if (IS_ERR(pwqt))
569                         return PTR_ERR(pwqt);
570                 pwqr_task_attach(pwqt, sb);
571         } else if (unlikely(pwqt->sb != sb)) {
572                 pwqr_task_detach(pwqt, pwqt->sb);
573                 pwqr_task_attach(pwqt, sb);
574         }
575
576         switch (command) {
577         case PWQR_WAIT:
578                 rc = do_pwqr_wait(sb, pwqt, true, (struct pwqr_ioc_wait __user *)arg);
579                 break;
580         case PWQR_PARK:
581                 rc = do_pwqr_wait(sb, pwqt, false, NULL);
582                 break;
583         }
584
585         if (unlikely(sb->dead)) {
586                 pwqr_task_detach(pwqt, pwqt->sb);
587                 return -EBADFD;
588         }
589         return rc;
590 }
591
static const struct file_operations pwqr_dev_fops = {
        .owner          = THIS_MODULE,
        .open           = pwqr_open,
        .release        = pwqr_release,
        .unlocked_ioctl = pwqr_ioctl,
#ifdef CONFIG_COMPAT
        /* NOTE(review): reusing the native handler assumes every ioctl
         * payload (ints, struct pwqr_ioc_wait) has an identical 32-bit
         * layout — confirm against pwqr.h */
        .compat_ioctl   = pwqr_ioctl,
#endif
};
601
602 /*****************************************************************************
603  * module
604  */
605 static int __init pwqr_start(void)
606 {
607         int i;
608
609         for (i = 0; i < PWQR_HASH_SIZE; i++) {
610                 spin_lock_init(&pwqr_tasks_hash[i].lock);
611                 INIT_HLIST_HEAD(&pwqr_tasks_hash[i].tasks);
612         }
613
614         /* Register as a character device */
615         pwqr_major = register_chrdev(0, "pwqr", &pwqr_dev_fops);
616         if (pwqr_major < 0) {
617                 printk(KERN_ERR "pwqr: register_chrdev() failed\n");
618                 return pwqr_major;
619         }
620
621         /* Create a device node */
622         pwqr_class = class_create(THIS_MODULE, PWQR_DEVICE_NAME);
623         if (IS_ERR(pwqr_class)) {
624                 printk(KERN_ERR "pwqr: Error creating raw class\n");
625                 unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
626                 return PTR_ERR(pwqr_class);
627         }
628         device_create(pwqr_class, NULL, MKDEV(pwqr_major, 0), NULL, PWQR_DEVICE_NAME);
629         printk(KERN_INFO "pwqr: PThreads Work Queues Regulator v1 loaded");
630         return 0;
631 }
632
/* Module exit: flush pending RCU frees, then tear down the device. */
static void __exit pwqr_end(void)
{
        /* wait for outstanding pwqr_sb_finalize()/kfree_rcu() callbacks */
        rcu_barrier();
        device_destroy(pwqr_class, MKDEV(pwqr_major, 0));
        class_destroy(pwqr_class);
        unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
}
640
641 module_init(pwqr_start);
642 module_exit(pwqr_end);
643
644 MODULE_LICENSE("GPL");
645 MODULE_AUTHOR("Pierre Habouzit <pierre.habouzit@intersec.com>");
646 MODULE_DESCRIPTION("PThreads Work Queues Regulator");
647
648 // vim:noet:sw=8:cinoptions+=\:0,L-1,=1s: