/*
 * Copyright (C) 2012   Pierre Habouzit <pierre.habouzit@intersec.com>
 * Copyright (C) 2012   Intersec SAS
 *
 * This file implements the Linux Pthread Workqueue Regulator, and is part
 * of the Linux kernel.
 *
 * The Linux Kernel is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published by
 * the Free Software Foundation.
 *
 * The Linux Kernel is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 *
 * You should have received a copy of the GNU General Public License version 2
 * along with The Linux Kernel.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/hash.h>
#include <linux/init.h>
#include <linux/kref.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/timer.h>
#include <linux/uaccess.h>
#include <linux/wait.h>

#ifndef CONFIG_PREEMPT_NOTIFIERS
#  error the PWQR module requires CONFIG_PREEMPT_NOTIFIERS
#endif

#include "pwqr.h"

#define PWQR_UNPARK_DELAY       (HZ / 10)
#define PWQR_HASH_BITS          5
#define PWQR_HASH_SIZE          (1 << PWQR_HASH_BITS)

struct pwqr_task_bucket {
        spinlock_t              lock;
        struct hlist_head       tasks;
};

struct pwqr_sb {
        struct kref             kref;
        struct rcu_head         rcu;
        struct timer_list       timer;
        wait_queue_head_t       wqh;
        pid_t                   tgid;

        unsigned                concurrency;      /* target number of runners */
        unsigned                registered;       /* threads attached to this sb */

        unsigned                running;          /* registered threads on a cpu */
        unsigned                waiting;          /* threads blocked in PWQR_WAIT */
        unsigned                quarantined;      /* waiters beyond concurrency */
        unsigned                parked;           /* threads blocked in PWQR_PARK */
        unsigned                overcommit_wakes; /* wakes from PWQR_WAKE_OC */

        unsigned                dead;             /* set when the fd is released */
};

struct pwqr_task {
        struct preempt_notifier notifier;
        struct hlist_node       link;
        struct rcu_head         rcu;
        struct task_struct     *task;
        struct pwqr_sb         *sb;
};

/*
 * Global variables
 */
static struct class            *pwqr_class;
static int                      pwqr_major;
static struct pwqr_task_bucket  pwqr_tasks_hash[PWQR_HASH_SIZE];
static struct preempt_ops       pwqr_preempt_running_ops;
static struct preempt_ops       pwqr_preempt_blocked_ops;
static struct preempt_ops       pwqr_preempt_noop_ops;

/*****************************************************************************
 * Scoreboards
 */

#define pwqr_sb_lock_irqsave(sb, flags) \
        spin_lock_irqsave(&(sb)->wqh.lock, flags)
#define pwqr_sb_unlock_irqrestore(sb, flags) \
        spin_unlock_irqrestore(&(sb)->wqh.lock, flags)

static inline void __pwqr_sb_update_state(struct pwqr_sb *sb, int running_delta)
{
        int overcommit;

        sb->running += running_delta;
        overcommit = sb->running + sb->waiting - sb->concurrency;
        if (overcommit == 0) {
                /* do nothing */
        } else if (overcommit > 0) {
                if (overcommit > sb->waiting) {
                        sb->quarantined += sb->waiting;
                        sb->waiting      = 0;
                } else {
                        sb->quarantined += overcommit;
                        sb->waiting     -= overcommit;
                }
        } else {
                unsigned undercommit = -overcommit;

                if (undercommit < sb->quarantined) {
                        sb->waiting     += undercommit;
                        sb->quarantined -= undercommit;
                } else if (sb->quarantined) {
                        sb->waiting     += sb->quarantined;
                        sb->quarantined  = 0;
                } else if (sb->waiting == 0 && sb->parked) {
                        if (!timer_pending(&sb->timer)) {
                                mod_timer(&sb->timer, jiffies +
                                          PWQR_UNPARK_DELAY);
                        }
                        return;
                }
        }

        if (timer_pending(&sb->timer))
                del_timer(&sb->timer);
}
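
/*
 * A worked example of the bookkeeping above (illustrative numbers, not
 * taken from a real trace): with concurrency = 4, running = 3 and
 * waiting = 2, a blocked thread resuming (running_delta = +1) gives
 * running = 4 and overcommit = 4 + 2 - 4 = 2, so both waiters move to
 * quarantined.  If two running threads then block (running_delta = -1,
 * twice), each call computes an undercommit of 1 and moves one
 * quarantined waiter back to waiting.
 */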

static void pwqr_sb_timer_cb(unsigned long arg)
{
        struct pwqr_sb *sb = (struct pwqr_sb *)arg;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->waiting == 0 && sb->parked && sb->running < sb->concurrency) {
                if (sb->overcommit_wakes == 0)
                        wake_up_locked(&sb->wqh);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
}

static struct pwqr_sb *pwqr_sb_create(void)
{
        struct pwqr_sb *sb;

        sb = kzalloc(sizeof(struct pwqr_sb), GFP_KERNEL);
        if (sb == NULL)
                return ERR_PTR(-ENOMEM);

        kref_init(&sb->kref);
        init_waitqueue_head(&sb->wqh);
        sb->tgid        = current->tgid;
        sb->concurrency = num_online_cpus();
        init_timer(&sb->timer);
        sb->timer.function = pwqr_sb_timer_cb;
        sb->timer.data     = (unsigned long)sb;

        __module_get(THIS_MODULE);
        return sb;
}

static inline void pwqr_sb_get(struct pwqr_sb *sb)
{
        kref_get(&sb->kref);
}

static void pwqr_sb_finalize(struct rcu_head *rcu)
{
        struct pwqr_sb *sb = container_of(rcu, struct pwqr_sb, rcu);

        module_put(THIS_MODULE);
        kfree(sb);
}

static void pwqr_sb_release(struct kref *kref)
{
        struct pwqr_sb *sb = container_of(kref, struct pwqr_sb, kref);

        del_timer_sync(&sb->timer);
        call_rcu(&sb->rcu, pwqr_sb_finalize);
}

static inline void pwqr_sb_put(struct pwqr_sb *sb)
{
        kref_put(&sb->kref, pwqr_sb_release);
}

/*****************************************************************************
 * tasks
 */
static inline struct pwqr_task_bucket *task_hbucket(struct task_struct *task)
{
        return &pwqr_tasks_hash[hash_ptr(task, PWQR_HASH_BITS)];
}

static struct pwqr_task *pwqr_task_find(struct task_struct *task)
{
        struct pwqr_task_bucket *b = task_hbucket(task);
        struct hlist_node *node;
        struct pwqr_task *pwqt = NULL;

        spin_lock(&b->lock);
        hlist_for_each_entry(pwqt, node, &b->tasks, link) {
                if (pwqt->task == task)
                        break;
        }
        spin_unlock(&b->lock);
        /* on a full traversal the cursor keeps the last entry, not NULL,
         * so use @node to detect a miss */
        return node ? pwqt : NULL;
}

static struct pwqr_task *pwqr_task_create(struct task_struct *task)
{
        struct pwqr_task_bucket *b = task_hbucket(task);
        struct pwqr_task *pwqt;

        pwqt = kmalloc(sizeof(*pwqt), GFP_KERNEL);
        if (pwqt == NULL)
                return ERR_PTR(-ENOMEM);

        /* publish ->task before the notifier can fire */
        pwqt->task = task;
        preempt_notifier_init(&pwqt->notifier, &pwqr_preempt_running_ops);
        preempt_notifier_register(&pwqt->notifier);

        spin_lock(&b->lock);
        hlist_add_head(&pwqt->link, &b->tasks);
        spin_unlock(&b->lock);

        return pwqt;
}

__cold
static void pwqr_task_detach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->registered--;
        if (pwqt->notifier.ops == &pwqr_preempt_running_ops) {
                __pwqr_sb_update_state(sb, -1);
        } else {
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
        pwqr_sb_put(sb);
        pwqt->sb = NULL;
}

__cold
static void pwqr_task_attach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        pwqr_sb_get(pwqt->sb = sb);
        sb->registered++;
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}

__cold
static void pwqr_task_release(struct pwqr_task *pwqt, bool from_notifier)
{
        struct pwqr_task_bucket *b = task_hbucket(pwqt->task);

        spin_lock(&b->lock);
        hlist_del(&pwqt->link);
        spin_unlock(&b->lock);
        pwqt->notifier.ops = &pwqr_preempt_noop_ops;

        if (from_notifier) {
                /* When called from sched_{in,out}, we are not allowed to
                 * call preempt_notifier_unregister() (or worse, kfree()).
                 *
                 * Freeing a still-registered notifier is only safe here
                 * because the task is dying: if it were not, the next
                 * sched_{in,out} would dereference freed memory and panic.
                 */
                BUG_ON(!(pwqt->task->state & TASK_DEAD));
                kfree_rcu(pwqt, rcu);
        } else {
                preempt_notifier_unregister(&pwqt->notifier);
                kfree(pwqt);
        }
}

static void pwqr_task_noop_sched_in(struct preempt_notifier *notifier, int cpu)
{
}

static void pwqr_task_noop_sched_out(struct preempt_notifier *notifier,
                                     struct task_struct *next)
{
}

static void pwqr_task_blocked_sched_in(struct preempt_notifier *notifier, int cpu)
{
        struct pwqr_task *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb   *sb   = pwqt->sb;
        unsigned long flags;

        if (unlikely(sb->dead)) {
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }

        pwqt->notifier.ops = &pwqr_preempt_running_ops;
        pwqr_sb_lock_irqsave(sb, flags);
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}

static void pwqr_task_sched_out(struct preempt_notifier *notifier,
                                struct task_struct *next)
{
        struct pwqr_task   *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb     *sb   = pwqt->sb;
        struct task_struct *p    = pwqt->task;

        if (unlikely(p->state & TASK_DEAD) || unlikely(sb->dead)) {
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }
        if (p->state == 0 || (p->state & (__TASK_STOPPED | __TASK_TRACED)))
                return;

        pwqt->notifier.ops = &pwqr_preempt_blocked_ops;
        /* see preempt.h: irqs are disabled for sched_out */
        spin_lock(&sb->wqh.lock);
        __pwqr_sb_update_state(sb, -1);
        spin_unlock(&sb->wqh.lock);
}

static struct preempt_ops __read_mostly pwqr_preempt_noop_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_noop_sched_out,
};

static struct preempt_ops __read_mostly pwqr_preempt_running_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_sched_out,
};

static struct preempt_ops __read_mostly pwqr_preempt_blocked_ops = {
        .sched_in       = pwqr_task_blocked_sched_in,
        .sched_out      = pwqr_task_sched_out,
};

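/*
 * A sketch of the notifier state machine, as inferred from the callbacks
 * above: a registered task normally carries pwqr_preempt_running_ops.
 * When it is scheduled out while genuinely blocking (not merely
 * preempted, stopped or traced), it is accounted as one fewer runner and
 * flipped to pwqr_preempt_blocked_ops; the next sched_in flips it back
 * and re-accounts it as a runner.  pwqr_preempt_noop_ops is the terminal
 * state used while a task is being released.
 */
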
/*****************************************************************************
 * file descriptor
 */
static int pwqr_open(struct inode *inode, struct file *filp)
{
        struct pwqr_sb *sb;

        sb = pwqr_sb_create();
        if (IS_ERR(sb))
                return PTR_ERR(sb);
        filp->private_data = sb;
        return 0;
}

static int pwqr_release(struct inode *inode, struct file *filp)
{
        struct pwqr_sb *sb = filp->private_data;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->dead = true;
        pwqr_sb_unlock_irqrestore(sb, flags);
        wake_up_all(&sb->wqh);
        pwqr_sb_put(sb);
        return 0;
}

static long
do_pwqr_wait(struct pwqr_sb *sb, struct pwqr_task *pwqt,
             int in_pool, struct pwqr_ioc_wait __user *arg)
{
        unsigned long flags;
        struct pwqr_ioc_wait wait;
        long rc = 0;
        u32 uval;

        preempt_notifier_unregister(&pwqt->notifier);

        if (in_pool && copy_from_user(&wait, arg, sizeof(wait))) {
                rc = -EFAULT;
                goto out;
        }

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->running + sb->waiting <= sb->concurrency) {
                if (in_pool) {
                        /* read the user word without faulting under the
                         * lock; if it would fault, drop the lock, fault
                         * the page in with get_user() and retry.
                         */
                        while (probe_kernel_address(wait.pwqr_uaddr, uval)) {
                                pwqr_sb_unlock_irqrestore(sb, flags);
                                rc = get_user(uval, (u32 __user *)wait.pwqr_uaddr);
                                if (rc)
                                        goto out;
                                pwqr_sb_lock_irqsave(sb, flags);
                        }

                        if (uval != (u32)wait.pwqr_ticket) {
                                rc = -EWOULDBLOCK;
                                goto out_unlock;
                        }
                } else {
                        BUG_ON(sb->quarantined != 0);
                        goto out_unlock;
                }
        }

        /* see wait_event_interruptible_exclusive_locked_irq() */
        if (likely(!sb->dead)) {
                DEFINE_WAIT(__wait);

                __wait.flags |= WQ_FLAG_EXCLUSIVE;

                if (in_pool) {
                        sb->waiting++;
                        __add_wait_queue(&sb->wqh, &__wait);
                } else {
                        sb->parked++;
                        __add_wait_queue_tail(&sb->wqh, &__wait);
                }
                __pwqr_sb_update_state(sb, -1);
                set_current_state(TASK_INTERRUPTIBLE);

                do {
                        if (sb->overcommit_wakes)
                                break;
                        if (signal_pending(current)) {
                                rc = -ERESTARTSYS;
                                break;
                        }
                        spin_unlock_irq(&sb->wqh.lock);
                        schedule();
                        spin_lock_irq(&sb->wqh.lock);
                        if (in_pool && sb->waiting)
                                break;
                        if (sb->running + sb->waiting < sb->concurrency)
                                break;
                } while (likely(!sb->dead));

                __remove_wait_queue(&sb->wqh, &__wait);
                __set_current_state(TASK_RUNNING);

                if (in_pool) {
                        if (sb->waiting) {
                                sb->waiting--;
                        } else {
                                BUG_ON(!sb->quarantined);
                                sb->quarantined--;
                        }
                } else {
                        sb->parked--;
                }
                __pwqr_sb_update_state(sb, 1);
                if (sb->overcommit_wakes)
                        sb->overcommit_wakes--;
                if (sb->waiting + sb->running > sb->concurrency)
                        rc = -EDQUOT;
        }

out_unlock:
        if (unlikely(sb->dead))
                rc = -EBADFD;
        pwqr_sb_unlock_irqrestore(sb, flags);
out:
        preempt_notifier_register(&pwqt->notifier);
        return rc;
}

static long do_pwqr_unregister(struct pwqr_sb *sb, struct pwqr_task *pwqt)
{
        if (!pwqt)
                return -EINVAL;
        if (pwqt->sb != sb)
                return -ENOENT;
        pwqr_task_detach(pwqt, sb);
        pwqr_task_release(pwqt, false);
        return 0;
}

static long do_pwqr_set_conc(struct pwqr_sb *sb, int conc)
{
        long old_conc;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        old_conc = sb->concurrency;     /* read under the lock */
        if (conc <= 0)
                conc = num_online_cpus();
        if (conc != old_conc) {
                sb->concurrency = conc;
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);

        return old_conc;
}

static long do_pwqr_wake(struct pwqr_sb *sb, int oc, int count)
{
        unsigned long flags;
        int nwake;

        if (count < 0)
                return -EINVAL;

        pwqr_sb_lock_irqsave(sb, flags);

        if (oc) {
                nwake = sb->waiting + sb->quarantined + sb->parked -
                        sb->overcommit_wakes;
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
                sb->overcommit_wakes += count;
        } else if (sb->running + sb->overcommit_wakes < sb->concurrency) {
                nwake = sb->concurrency - sb->overcommit_wakes - sb->running;
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
        } else {
                /*
                 * This codepath deserves an explanation: as far as we are
                 * concerned, a quarantined thread is already "parked", but
                 * userland does not know that yet.  So wake as many threads
                 * as userland asked for, and let the wakeup itself tell
                 * those threads that they should park.
                 *
                 * That is also why we lie about the number of woken
                 * threads: each of these wakeups merely lets a thread park
                 * for real and saves a spurious syscall, so from userland's
                 * point of view it is as if we woke up 0 threads.
                 */
                nwake = sb->quarantined;
                if (sb->waiting < sb->overcommit_wakes)
                        nwake -= sb->overcommit_wakes - sb->waiting;
                if (nwake > count)
                        nwake = count;
                count = 0;
        }
        while (nwake-- > 0)
                wake_up_locked(&sb->wqh);
        pwqr_sb_unlock_irqrestore(sb, flags);

        return count;
}
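
/*
 * For instance (illustrative numbers): with concurrency = 4, running = 4,
 * waiting = 0, parked = 0, quarantined = 2 and overcommit_wakes = 0, a
 * PWQR_WAKE with count = 2 takes the third branch above: both quarantined
 * threads are woken so that they can park for real, and the ioctl reports
 * 0 threads woken.
 */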

static long pwqr_ioctl(struct file *filp, unsigned command, unsigned long arg)
{
        struct pwqr_sb     *sb   = filp->private_data;
        struct task_struct *task = current;
        struct pwqr_task   *pwqt;
        long rc = 0;

        if (sb->tgid != current->tgid)
                return -EBADFD;

        switch (command) {
        case PWQR_GET_CONC:
                return sb->concurrency;
        case PWQR_SET_CONC:
                return do_pwqr_set_conc(sb, (int)arg);

        case PWQR_WAKE:
        case PWQR_WAKE_OC:
                return do_pwqr_wake(sb, command == PWQR_WAKE_OC, (int)arg);

        case PWQR_WAIT:
        case PWQR_PARK:
        case PWQR_REGISTER:
        case PWQR_UNREGISTER:
                break;
        default:
                return -EINVAL;
        }

        pwqt = pwqr_task_find(task);
        if (command == PWQR_UNREGISTER)
                return do_pwqr_unregister(sb, pwqt);

        if (pwqt == NULL) {
                pwqt = pwqr_task_create(task);
                if (IS_ERR(pwqt))
                        return PTR_ERR(pwqt);
                pwqr_task_attach(pwqt, sb);
        } else if (unlikely(pwqt->sb != sb)) {
                pwqr_task_detach(pwqt, pwqt->sb);
                pwqr_task_attach(pwqt, sb);
        }

        switch (command) {
        case PWQR_WAIT:
                rc = do_pwqr_wait(sb, pwqt, true, (struct pwqr_ioc_wait __user *)arg);
                break;
        case PWQR_PARK:
                rc = do_pwqr_wait(sb, pwqt, false, NULL);
                break;
        }

        if (unlikely(sb->dead)) {
                pwqr_task_detach(pwqt, pwqt->sb);
                return -EBADFD;
        }
        return rc;
}

static const struct file_operations pwqr_dev_fops = {
        .owner          = THIS_MODULE,
        .open           = pwqr_open,
        .release        = pwqr_release,
        .unlocked_ioctl = pwqr_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = pwqr_ioctl,
#endif
};

/*****************************************************************************
 * module
 */
static int __init pwqr_start(void)
{
        int i;

        for (i = 0; i < PWQR_HASH_SIZE; i++) {
                spin_lock_init(&pwqr_tasks_hash[i].lock);
                INIT_HLIST_HEAD(&pwqr_tasks_hash[i].tasks);
        }

        /* Register as a character device */
        pwqr_major = register_chrdev(0, "pwqr", &pwqr_dev_fops);
        if (pwqr_major < 0) {
                printk(KERN_ERR "pwqr: register_chrdev() failed\n");
                return pwqr_major;
        }

        /* Create a device node */
        pwqr_class = class_create(THIS_MODULE, PWQR_DEVICE_NAME);
        if (IS_ERR(pwqr_class)) {
                printk(KERN_ERR "pwqr: error creating pwqr class\n");
                unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
                return PTR_ERR(pwqr_class);
        }
        device_create(pwqr_class, NULL, MKDEV(pwqr_major, 0), NULL,
                      PWQR_DEVICE_NAME);
        printk(KERN_INFO "pwqr: PThreads Work Queues Regulator v1 loaded\n");
        return 0;
}

static void __exit pwqr_end(void)
{
        rcu_barrier();
        device_destroy(pwqr_class, MKDEV(pwqr_major, 0));
        class_destroy(pwqr_class);
        unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
}

module_init(pwqr_start);
module_exit(pwqr_end);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Pierre Habouzit <pierre.habouzit@intersec.com>");
MODULE_DESCRIPTION("PThreads Work Queues Regulator");
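
/*
 * Userland usage sketch, to illustrate the ioctl protocol above.  This
 * is an assumption-laden example, not part of the module: it supposes
 * that "pwqr.h" defines the ioctl numbers and a struct pwqr_ioc_wait
 * with the pwqr_uaddr/pwqr_ticket fields used by do_pwqr_wait(), and
 * that udev exposed the device node as /dev/pwqr:
 *
 *	int fd = open("/dev/pwqr", O_RDWR);
 *
 *	// every thread of the pool registers itself once
 *	ioctl(fd, PWQR_REGISTER);
 *
 *	// idle worker: futex-style, sleep only while *uaddr still equals
 *	// ticket; -EDQUOT on return means "please PWQR_PARK yourself"
 *	struct pwqr_ioc_wait w = { .pwqr_ticket = ticket,
 *	                           .pwqr_uaddr  = &wakeup_seq };
 *	int rc = ioctl(fd, PWQR_WAIT, &w);
 *
 *	// producer: bump the sequence, then wake up to n waiters
 *	wakeup_seq++;
 *	ioctl(fd, PWQR_WAKE, n);
 */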

// vim:noet:sw=8:cinoptions+=\:0,L-1,=1s: