Prepare code to plug the overcommit notification.
[~madcoder/pwqr.git] kernel/pwqr.c
/*
 * Copyright (C) 2012   Pierre Habouzit <pierre.habouzit@intersec.com>
 * Copyright (C) 2012   Intersec SAS
 *
 * This file implements the Linux Pthread Workqueue Regulator, and is part
 * of the linux kernel.
 *
 * The Linux Kernel is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published by
 * the Free Software Foundation.
 *
 * The Linux Kernel is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 *
 * You should have received a copy of the GNU General Public License version 2
 * along with The Linux Kernel.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/hash.h>
#include <linux/init.h>
#include <linux/kref.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/timer.h>
#include <linux/uaccess.h>
#include <linux/wait.h>

#ifndef CONFIG_PREEMPT_NOTIFIERS
#  error PWQ module requires CONFIG_PREEMPT_NOTIFIERS
#endif

#include "pwqr.h"

#define PWQR_HASH_BITS          5
#define PWQR_HASH_SIZE          (1 << PWQR_HASH_BITS)

#define PWQR_UC_DELAY           (HZ / 10)
#define PWQR_OC_DELAY           (HZ / 20)

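/*
 * Scoreboard states: UC (undercommit) means fewer threads run than the
 * concurrency target while parked threads exist; OC (overcommit) means more
 * threads run than the target.  Neither state is acted upon immediately:
 * pwqr_arm_timer() delays the reaction (PWQR_UC_DELAY, PWQR_OC_DELAY) so
 * that short transients are ignored.
 */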
#define PWQR_STATE_NONE         0
#define PWQR_STATE_UC           1
#define PWQR_STATE_OC           2
#define PWQR_STATE_DEAD         (-1)
struct pwqr_task_bucket {
        spinlock_t              lock;
        struct hlist_head       tasks;
};

struct pwqr_sb {
        struct kref             kref;
        struct rcu_head         rcu;
        struct timer_list       timer;
        wait_queue_head_t       wqh;
        pid_t                   tgid;

        unsigned                concurrency;
        unsigned                registered;

        unsigned                running;
        unsigned                waiting;
        unsigned                parked;
        unsigned                overcommit_wakes;

        int                     state;
};
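
/*
 * A scoreboard (pwqr_sb) regulates one process (tgid): it tracks how many
 * registered threads currently run, how many block in PWQR_WAIT ("waiting")
 * or PWQR_PARK ("parked"), and how many wakeups were issued on purpose
 * beyond the concurrency target ("overcommit_wakes").  All the counters are
 * protected by the wait-queue head's spinlock (see the lock macros below).
 */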

struct pwqr_task {
        struct preempt_notifier notifier;
        struct hlist_node       link;
        struct rcu_head         rcu;
        struct task_struct     *task;
        struct pwqr_sb         *sb;
};

/*
 * Global variables
 */
static struct class            *pwqr_class;
static int                      pwqr_major;
static struct pwqr_task_bucket  pwqr_tasks_hash[PWQR_HASH_SIZE];
static struct preempt_ops       pwqr_preempt_running_ops;
static struct preempt_ops       pwqr_preempt_blocked_ops;
static struct preempt_ops       pwqr_preempt_noop_ops;

/*****************************************************************************
 * Scoreboards
 */

#define pwqr_sb_lock_irqsave(sb, flags) \
        spin_lock_irqsave(&(sb)->wqh.lock, flags)
#define pwqr_sb_unlock_irqrestore(sb, flags) \
        spin_unlock_irqrestore(&(sb)->wqh.lock, flags)

static inline void pwqr_arm_timer(struct pwqr_sb *sb, int how, int delay)
{
        if (timer_pending(&sb->timer) && sb->state == how)
                return;
        mod_timer(&sb->timer, jiffies + delay);
        sb->state = how;
}

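/*
 * Re-evaluate the scoreboard after "running" changed by running_delta.
 * Undercommit (fewer runners than the target while workers sit parked) and
 * overcommit (more runners than the target) arm the timer rather than act
 * at once: the condition must still hold when the timer fires.  Must be
 * called with sb->wqh.lock held.
 */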
static inline void __pwqr_sb_update_state(struct pwqr_sb *sb, int running_delta)
{
        sb->running += running_delta;

        if (sb->running < sb->concurrency && sb->waiting == 0 && sb->parked) {
                pwqr_arm_timer(sb, PWQR_STATE_UC, PWQR_UC_DELAY);
        } else if (sb->running > sb->concurrency) {
                pwqr_arm_timer(sb, PWQR_STATE_OC, PWQR_OC_DELAY);
        } else {
                sb->state = PWQR_STATE_NONE;
                del_timer(&sb->timer);
        }
}

static void pwqr_sb_timer_cb(unsigned long arg)
{
        struct pwqr_sb *sb = (struct pwqr_sb *)arg;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->running < sb->concurrency && sb->waiting == 0 && sb->parked) {
                if (sb->overcommit_wakes == 0)
                        wake_up_locked(&sb->wqh);
        }
        if (sb->running > sb->concurrency) {
                /* See ../Documentation/pwqr.adoc */
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
}
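
/*
 * The empty overcommit branch above is deliberate: this change only
 * prepares the spot where the overcommit notification will be plugged in
 * (see the commit subject and ../Documentation/pwqr.adoc).
 */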

static struct pwqr_sb *pwqr_sb_create(void)
{
        struct pwqr_sb *sb;

        sb = kzalloc(sizeof(struct pwqr_sb), GFP_KERNEL);
        if (sb == NULL)
                return ERR_PTR(-ENOMEM);

        kref_init(&sb->kref);
        init_waitqueue_head(&sb->wqh);
        sb->tgid        = current->tgid;
        sb->concurrency = num_online_cpus();
        init_timer(&sb->timer);
        sb->timer.function = pwqr_sb_timer_cb;
        sb->timer.data     = (unsigned long)sb;

        __module_get(THIS_MODULE);
        return sb;
}
static inline void pwqr_sb_get(struct pwqr_sb *sb)
{
        kref_get(&sb->kref);
}

static void pwqr_sb_finalize(struct rcu_head *rcu)
{
        struct pwqr_sb *sb = container_of(rcu, struct pwqr_sb, rcu);

        module_put(THIS_MODULE);
        kfree(sb);
}

static void pwqr_sb_release(struct kref *kref)
{
        struct pwqr_sb *sb = container_of(kref, struct pwqr_sb, kref);

        del_timer_sync(&sb->timer);
        call_rcu(&sb->rcu, pwqr_sb_finalize);
}
static inline void pwqr_sb_put(struct pwqr_sb *sb)
{
        kref_put(&sb->kref, pwqr_sb_release);
}
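
/*
 * Scoreboard lifetime: the last pwqr_sb_put() stops the timer
 * synchronously, then defers the actual kfree() through RCU, so code that
 * still dereferences an sb pointer across the current grace period (such
 * as the preempt notifiers) finishes before the memory goes away; the
 * module reference taken in pwqr_sb_create() is dropped at that point too.
 */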

/*****************************************************************************
 * tasks
 */
static inline struct pwqr_task_bucket *task_hbucket(struct task_struct *task)
{
        return &pwqr_tasks_hash[hash_ptr(task, PWQR_HASH_BITS)];
}

static struct pwqr_task *pwqr_task_find(struct task_struct *task)
{
        struct pwqr_task_bucket *b = task_hbucket(task);
        struct hlist_node *node;
        struct pwqr_task *pos, *pwqt = NULL;

        spin_lock(&b->lock);
        hlist_for_each_entry(pos, node, &b->tasks, link) {
                if (pos->task == task) {
                        pwqt = pos;
                        break;
                }
        }
        spin_unlock(&b->lock);
        return pwqt;
}

static struct pwqr_task *pwqr_task_create(struct task_struct *task)
{
        struct pwqr_task_bucket *b = task_hbucket(task);
        struct pwqr_task *pwqt;

        pwqt = kmalloc(sizeof(*pwqt), GFP_KERNEL);
        if (pwqt == NULL)
                return ERR_PTR(-ENOMEM);

        preempt_notifier_init(&pwqt->notifier, &pwqr_preempt_running_ops);
        preempt_notifier_register(&pwqt->notifier);
        pwqt->task = task;

        spin_lock(&b->lock);
        hlist_add_head(&pwqt->link, &b->tasks);
        spin_unlock(&b->lock);

        return pwqt;
}

__cold
static void pwqr_task_detach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->registered--;
        if (pwqt->notifier.ops == &pwqr_preempt_running_ops) {
                __pwqr_sb_update_state(sb, -1);
        } else {
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
        pwqr_sb_put(sb);
        pwqt->sb = NULL;
}

__cold
static void pwqr_task_attach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        pwqr_sb_get(pwqt->sb = sb);
        sb->registered++;
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}

__cold
static void pwqr_task_release(struct pwqr_task *pwqt, bool from_notifier)
{
        struct pwqr_task_bucket *b = task_hbucket(pwqt->task);

        spin_lock(&b->lock);
        hlist_del(&pwqt->link);
        spin_unlock(&b->lock);
        pwqt->notifier.ops = &pwqr_preempt_noop_ops;

        if (from_notifier) {
                /* When called from sched_{out,in}, calling
                 * preempt_notifier_unregister() (or worse, kfree()) is
                 * not allowed.
                 *
                 * kfree()ing a still-registered notifier is only safe
                 * here because the task is dying: were it not, the next
                 * sched_{in,out} call would panic.
                 */
                BUG_ON(!(pwqt->task->state & TASK_DEAD));
                kfree_rcu(pwqt, rcu);
        } else {
                preempt_notifier_unregister(&pwqt->notifier);
                kfree(pwqt);
        }
}

static void pwqr_task_noop_sched_in(struct preempt_notifier *notifier, int cpu)
{
}

static void pwqr_task_noop_sched_out(struct preempt_notifier *notifier,
                                     struct task_struct *next)
{
}

static void pwqr_task_blocked_sched_in(struct preempt_notifier *notifier, int cpu)
{
        struct pwqr_task *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb   *sb   = pwqt->sb;
        unsigned long flags;

        if (unlikely(sb->state < 0)) {
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }

        pwqt->notifier.ops = &pwqr_preempt_running_ops;
        pwqr_sb_lock_irqsave(sb, flags);
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}

static void pwqr_task_sched_out(struct preempt_notifier *notifier,
                                struct task_struct *next)
{
        struct pwqr_task   *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb     *sb   = pwqt->sb;
        struct task_struct *p    = pwqt->task;

        if (unlikely(p->state & TASK_DEAD) || unlikely(sb->state < 0)) {
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }
        if (p->state == 0 || (p->state & (__TASK_STOPPED | __TASK_TRACED)))
                return;

        pwqt->notifier.ops = &pwqr_preempt_blocked_ops;
        /* see preempt.h: irqs are disabled for sched_out */
        spin_lock(&sb->wqh.lock);
        __pwqr_sb_update_state(sb, -1);
        spin_unlock(&sb->wqh.lock);
}

static struct preempt_ops __read_mostly pwqr_preempt_noop_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_noop_sched_out,
};

static struct preempt_ops __read_mostly pwqr_preempt_running_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_sched_out,
};

static struct preempt_ops __read_mostly pwqr_preempt_blocked_ops = {
        .sched_in       = pwqr_task_blocked_sched_in,
        .sched_out      = pwqr_task_sched_out,
};
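
/*
 * A registered thread hops between three preempt_ops sets: "running"
 * (accounts a voluntary sleep by switching itself to "blocked" on
 * sched_out), "blocked" (switches back to "running" and bumps the running
 * count on sched_in), and "noop" (installed once the pwqr_task is being
 * torn down, so that late notifier calls do nothing).
 */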

/*****************************************************************************
 * file descriptor
 */
static int pwqr_open(struct inode *inode, struct file *filp)
{
        struct pwqr_sb *sb;

        sb = pwqr_sb_create();
        if (IS_ERR(sb))
                return PTR_ERR(sb);
        filp->private_data = sb;
        return 0;
}

static int pwqr_release(struct inode *inode, struct file *filp)
{
        struct pwqr_sb *sb = filp->private_data;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->state = PWQR_STATE_DEAD;
        pwqr_sb_unlock_irqrestore(sb, flags);
        wake_up_all(&sb->wqh);
        pwqr_sb_put(sb);
        return 0;
}

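/*
 * Common implementation of PWQR_WAIT and PWQR_PARK.  The caller's preempt
 * notifier is unregistered for the duration of the call so that the sleep
 * below is not accounted as a worker blocking.  When the pool still has
 * room, PWQR_WAIT compares the word at pwqr_uaddr with the ticket passed
 * by userland and fails with EWOULDBLOCK on mismatch (the usual
 * futex-style race-free sleep protocol), while PWQR_PARK returns at once.
 */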
static long
do_pwqr_wait(struct pwqr_sb *sb, struct pwqr_task *pwqt,
             int is_wait, struct pwqr_ioc_wait __user *arg)
{
        unsigned long flags;
        struct pwqr_ioc_wait wait;
        long rc = 0;
        u32 uval;

        preempt_notifier_unregister(&pwqt->notifier);

        if (is_wait) {
                if (copy_from_user(&wait, arg, sizeof(wait))) {
                        rc = -EFAULT;
                        goto out;
                }
                if (unlikely((long)wait.pwqr_uaddr % sizeof(int) != 0)) {
                        rc = -EINVAL;
                        goto out;
                }
        }

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->running + sb->waiting <= sb->concurrency) {
                if (is_wait) {
                        while (probe_kernel_address(wait.pwqr_uaddr, uval)) {
                                pwqr_sb_unlock_irqrestore(sb, flags);
                                rc = get_user(uval, (u32 __user *)wait.pwqr_uaddr);
                                if (rc)
                                        goto out;
                                pwqr_sb_lock_irqsave(sb, flags);
                        }

                        if (uval != (u32)wait.pwqr_ticket) {
                                rc = -EWOULDBLOCK;
                                goto out_unlock;
                        }
                } else {
                        goto out_unlock;
                }
        }

        /* see wait_event_interruptible_exclusive_locked_irq() */
        if (likely(sb->state >= 0)) {
                DEFINE_WAIT(__wait);

                __wait.flags |= WQ_FLAG_EXCLUSIVE;

                if (is_wait) {
                        sb->waiting++;
                } else {
                        sb->parked++;
                }
                __pwqr_sb_update_state(sb, -1);

                do {
                        /* DEFINE_WAIT uses autoremove_wake_function(), so
                         * a wake takes us off the queue: requeue ourselves
                         * and go back to INTERRUPTIBLE before re-checking
                         * the conditions, as wait_event_*_locked does. */
                        if (list_empty(&__wait.task_list)) {
                                if (is_wait)
                                        __add_wait_queue(&sb->wqh, &__wait);
                                else
                                        __add_wait_queue_tail(&sb->wqh, &__wait);
                        }
                        set_current_state(TASK_INTERRUPTIBLE);

                        if (sb->overcommit_wakes)
                                break;
                        if (signal_pending(current)) {
                                rc = -ERESTARTSYS;
                                break;
                        }
                        spin_unlock_irq(&sb->wqh.lock);
                        schedule();
                        spin_lock_irq(&sb->wqh.lock);
                        if (is_wait)
                                break;
                        if (sb->running + sb->waiting < sb->concurrency)
                                break;
                } while (likely(sb->state >= 0));

                __remove_wait_queue(&sb->wqh, &__wait);
                __set_current_state(TASK_RUNNING);

                if (is_wait) {
                        sb->waiting--;
                } else {
                        sb->parked--;
                }
                __pwqr_sb_update_state(sb, 1);
                if (sb->overcommit_wakes)
                        sb->overcommit_wakes--;
                if (sb->waiting + sb->running > sb->concurrency)
                        rc = -EDQUOT;
        }

out_unlock:
        if (unlikely(sb->state < 0))
                rc = -EBADFD;
        pwqr_sb_unlock_irqrestore(sb, flags);
out:
        preempt_notifier_register(&pwqt->notifier);
        return rc;
}

static long do_pwqr_unregister(struct pwqr_sb *sb, struct pwqr_task *pwqt)
{
        if (!pwqt)
                return -EINVAL;
        if (pwqt->sb != sb)
                return -ENOENT;
        pwqr_task_detach(pwqt, sb);
        pwqr_task_release(pwqt, false);
        return 0;
}

static long do_pwqr_set_conc(struct pwqr_sb *sb, int conc)
{
        long old_conc;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        old_conc = sb->concurrency;
        if (conc <= 0)
                conc = num_online_cpus();
        if (conc != old_conc) {
                sb->concurrency = conc;
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);

        return old_conc;
}

static long do_pwqr_wake(struct pwqr_sb *sb, int oc, int count)
{
        unsigned long flags;
        int nwake;

        if (count < 0)
                return -EINVAL;

        pwqr_sb_lock_irqsave(sb, flags);

        if (oc) {
                nwake = sb->waiting + sb->parked - sb->overcommit_wakes;
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
                sb->overcommit_wakes += count;
        } else if (sb->running + sb->overcommit_wakes < sb->concurrency) {
                nwake = sb->concurrency - sb->overcommit_wakes - sb->running;
                if (nwake > sb->waiting + sb->parked - sb->overcommit_wakes) {
                        nwake = sb->waiting + sb->parked -
                                sb->overcommit_wakes;
                }
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
        } else {
                /*
                 * This codepath deserves an explanation: waking a thread
                 * "for real" would overcommit, even though userspace KNOWS
                 * there is at least one waiting thread.  Such threads are
                 * "quarantined".
                 *
                 * Quarantined threads are woken up one by one, to allow a
                 * slow ramp down and minimize "waiting" <-> "parked"
                 * flip-flops, no matter how many wakes were asked for.
                 *
                 * Since releasing one quarantined thread wakes up a thread
                 * that will (almost certainly) go straight back to parked
                 * mode, lie to userland about having unblocked it, and
                 * return 0.
                 *
                 * If we're already waking all waiting threads for
                 * overcommitting jobs, though, none of this is needed.
                 */
                count = 0;
                nwake = sb->waiting > sb->overcommit_wakes;
        }
        while (nwake-- > 0)
                wake_up_locked(&sb->wqh);
        pwqr_sb_unlock_irqrestore(sb, flags);

        return count;
}
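
/*
 * Worked example for the regular (non-OC) wake path, as a sanity check:
 * with concurrency = 4, running = 2, overcommit_wakes = 0, waiting = 3 and
 * parked = 0, a PWQR_WAKE asking for 5 wakes computes nwake = 4 - 0 - 2 = 2,
 * clamps it against the 3 sleepers, wakes two waiters and reports 2 to
 * userland.
 */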

static long pwqr_ioctl(struct file *filp, unsigned command, unsigned long arg)
{
        struct pwqr_sb     *sb   = filp->private_data;
        struct task_struct *task = current;
        struct pwqr_task   *pwqt;
        int rc = 0;

        if (sb->tgid != current->tgid)
                return -EBADFD;

        switch (command) {
        case PWQR_GET_CONC:
                return sb->concurrency;
        case PWQR_SET_CONC:
                return do_pwqr_set_conc(sb, (int)arg);

        case PWQR_WAKE:
        case PWQR_WAKE_OC:
                return do_pwqr_wake(sb, command == PWQR_WAKE_OC, (int)arg);

        case PWQR_WAIT:
        case PWQR_PARK:
        case PWQR_REGISTER:
        case PWQR_UNREGISTER:
                break;
        default:
                return -EINVAL;
        }

        pwqt = pwqr_task_find(task);
        if (command == PWQR_UNREGISTER)
                return do_pwqr_unregister(sb, pwqt);

        if (pwqt == NULL) {
                pwqt = pwqr_task_create(task);
                if (IS_ERR(pwqt))
                        return PTR_ERR(pwqt);
                pwqr_task_attach(pwqt, sb);
        } else if (unlikely(pwqt->sb != sb)) {
                pwqr_task_detach(pwqt, pwqt->sb);
                pwqr_task_attach(pwqt, sb);
        }

        switch (command) {
        case PWQR_WAIT:
                rc = do_pwqr_wait(sb, pwqt, true, (struct pwqr_ioc_wait __user *)arg);
                break;
        case PWQR_PARK:
                rc = do_pwqr_wait(sb, pwqt, false, NULL);
                break;
        }

        if (unlikely(sb->state < 0)) {
                pwqr_task_detach(pwqt, pwqt->sb);
                return -EBADFD;
        }
        return rc;
}

static const struct file_operations pwqr_dev_fops = {
        .owner          = THIS_MODULE,
        .open           = pwqr_open,
        .release        = pwqr_release,
        .unlocked_ioctl = pwqr_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = pwqr_ioctl,
#endif
};
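
/*
 * Userland usage sketch (illustrative only: the ioctl names and the
 * pwqr_ioc_wait layout come from pwqr.h above; the device path assumes
 * PWQR_DEVICE_NAME expands to "pwqr", and "ticket_word" is a hypothetical
 * shared counter the workqueue bumps when it posts a job):
 *
 *      int fd = open("/dev/pwqr", O_RDWR);
 *      int conc = ioctl(fd, PWQR_GET_CONC, 0);  // current target
 *      ioctl(fd, PWQR_REGISTER);                // join the regulator
 *
 *      // Idle worker: sleep until there is room.  A wake racing with
 *      // this call is caught by the ticket comparison.
 *      struct pwqr_ioc_wait w = {
 *              .pwqr_ticket = ticket,
 *              .pwqr_uaddr  = &ticket_word,
 *      };
 *      if (ioctl(fd, PWQR_WAIT, &w) < 0) {
 *              // errno is EWOULDBLOCK (stale ticket), EDQUOT (woken
 *              // while overcommitted: please park), EINTR, ...
 *      }
 */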

/*****************************************************************************
 * module
 */
static int __init pwqr_start(void)
{
        int i;

        for (i = 0; i < PWQR_HASH_SIZE; i++) {
                spin_lock_init(&pwqr_tasks_hash[i].lock);
                INIT_HLIST_HEAD(&pwqr_tasks_hash[i].tasks);
        }

        /* Register as a character device */
        pwqr_major = register_chrdev(0, PWQR_DEVICE_NAME, &pwqr_dev_fops);
        if (pwqr_major < 0) {
                printk(KERN_ERR "pwqr: register_chrdev() failed\n");
                return pwqr_major;
        }

        /* Create a device node */
        pwqr_class = class_create(THIS_MODULE, PWQR_DEVICE_NAME);
        if (IS_ERR(pwqr_class)) {
                printk(KERN_ERR "pwqr: Error creating raw class\n");
                unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
                return PTR_ERR(pwqr_class);
        }
        device_create(pwqr_class, NULL, MKDEV(pwqr_major, 0), NULL, PWQR_DEVICE_NAME);
        printk(KERN_INFO "pwqr: PThreads Work Queues Regulator v1 loaded\n");
        return 0;
}

static void __exit pwqr_end(void)
{
        rcu_barrier();
        device_destroy(pwqr_class, MKDEV(pwqr_major, 0));
        class_destroy(pwqr_class);
        unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
}

module_init(pwqr_start);
module_exit(pwqr_end);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Pierre Habouzit <pierre.habouzit@intersec.com>");
MODULE_DESCRIPTION("PThreads Work Queues Regulator");

// vim:noet:sw=8:cinoptions+=\:0,L-1,=1s: