/*
 * Copyright (C) 2012   Pierre Habouzit <pierre.habouzit@intersec.com>
 * Copyright (C) 2012   Intersec SAS
 *
 * This file implements the Linux Pthread Workqueue Regulator, and is part
 * of the linux kernel.
 *
 * The Linux Kernel is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published by
 * the Free Software Foundation.
 *
 * The Linux Kernel is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 *
 * You should have received a copy of the GNU General Public License version 2
 * along with The Linux Kernel.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/hash.h>
#include <linux/init.h>
#include <linux/kref.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/timer.h>
#include <linux/uaccess.h>
#include <linux/wait.h>

#ifndef CONFIG_PREEMPT_NOTIFIERS
#  error PWQ module requires CONFIG_PREEMPT_NOTIFIERS
#endif

#include "pwqr.h"

#define PWQR_HASH_BITS          5
#define PWQR_HASH_SIZE          (1 << PWQR_HASH_BITS)

#define PWQR_UC_DELAY           (HZ / 10)       /* undercommit check delay */
#define PWQR_OC_DELAY           (HZ / 20)       /* overcommit check delay  */

#define PWQR_STATE_NONE         0               /* balanced, no check pending */
#define PWQR_STATE_UC           1               /* undercommit check pending  */
#define PWQR_STATE_OC           2               /* overcommit check pending   */
#define PWQR_STATE_DEAD         (-1)            /* fd released, sb dying      */

struct pwqr_task_bucket {
        spinlock_t              lock;
        struct hlist_head       tasks;
};

struct pwqr_sb {
        struct kref             kref;
        struct rcu_head         rcu;
        struct timer_list       timer;
        wait_queue_head_t       wqh;

        unsigned                concurrency;    /* target count of running threads */
        unsigned                registered;     /* threads attached to this sb     */

        unsigned                running;        /* attached threads neither blocked
                                                 * nor waiting/parked              */
        unsigned                waiting;        /* threads asleep in PWQR_WAIT     */
        unsigned                parked;         /* threads asleep in PWQR_PARK     */
        unsigned                overcommit_wakes; /* forced wakes not yet consumed */

        int                     state;          /* PWQR_STATE_* */
};

struct pwqr_task {
        struct preempt_notifier notifier;
        struct hlist_node       link;
        struct rcu_head         rcu;
        struct task_struct     *task;
        struct pwqr_sb         *sb;
};

/*
 * Global variables
 */
static struct class            *pwqr_class;
static int                      pwqr_major;
static struct pwqr_task_bucket  pwqr_tasks_hash[PWQR_HASH_SIZE];
static struct preempt_ops       pwqr_preempt_running_ops;
static struct preempt_ops       pwqr_preempt_blocked_ops;
static struct preempt_ops       pwqr_preempt_noop_ops;

/*****************************************************************************
 * Scoreboards
 */

#define pwqr_sb_lock_irqsave(sb, flags) \
        spin_lock_irqsave(&(sb)->wqh.lock, flags)
#define pwqr_sb_unlock_irqrestore(sb, flags) \
        spin_unlock_irqrestore(&(sb)->wqh.lock, flags)

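/*
 * (Re)arm the scoreboard timer for state @how, unless a timer for that same
 * state is already pending: repeated hints while in the same state must not
 * push the deadline back forever.
 */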
static inline void pwqr_arm_timer(struct pwqr_sb *sb, int how, int delay)
{
        if (timer_pending(&sb->timer) && sb->state == how)
                return;
        mod_timer(&sb->timer, jiffies + delay);
        sb->state = how;
}

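/*
 * Core regulation step, called with sb->wqh.lock held. After applying
 * @running_delta, decide whether the pool looks undercommitted (fewer
 * running threads than the concurrency target while parked threads are
 * available), overcommitted (more running threads than the target), or
 * balanced, and arm or cancel the check timer accordingly.
 */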
static inline void __pwqr_sb_update_state(struct pwqr_sb *sb, int running_delta)
{
        sb->running += running_delta;

        if (sb->running < sb->concurrency && sb->waiting == 0 && sb->parked) {
                pwqr_arm_timer(sb, PWQR_STATE_UC, PWQR_UC_DELAY);
        } else if (sb->running > sb->concurrency) {
                pwqr_arm_timer(sb, PWQR_STATE_OC, PWQR_OC_DELAY);
        } else {
                sb->state = PWQR_STATE_NONE;
                /* cancel any pending check: the pool is balanced again.
                 * del_timer() is a no-op on an inactive timer, so no
                 * timer_pending() test is needed (the old one was inverted
                 * and never cancelled anything). */
                del_timer(&sb->timer);
        }
}

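/*
 * Timer callback: if the pool is still undercommitted once the delay has
 * elapsed, wake one waiter (unless overcommit wakes are already in flight).
 * The overcommit branch is deliberately empty; see the referenced document.
 */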
static void pwqr_sb_timer_cb(unsigned long arg)
{
        struct pwqr_sb *sb = (struct pwqr_sb *)arg;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->running < sb->concurrency && sb->waiting == 0 && sb->parked) {
                if (sb->overcommit_wakes == 0)
                        wake_up_locked(&sb->wqh);
        }
        if (sb->running > sb->concurrency) {
                /* See ../Documentation/pwqr.adoc */
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
}

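/*
 * Scoreboard lifetime: each pwqr_sb pins the module and is refcounted;
 * the final kref_put() defers freeing through RCU so that notifier
 * callbacks still walking the structure stay safe.
 */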
static struct pwqr_sb *pwqr_sb_create(void)
{
        struct pwqr_sb *sb;

        sb = kzalloc(sizeof(struct pwqr_sb), GFP_KERNEL);
        if (sb == NULL)
                return ERR_PTR(-ENOMEM);

        kref_init(&sb->kref);
        init_waitqueue_head(&sb->wqh);
        sb->concurrency    = num_online_cpus();
        init_timer(&sb->timer);
        sb->timer.function = pwqr_sb_timer_cb;
        sb->timer.data     = (unsigned long)sb;

        __module_get(THIS_MODULE);
        return sb;
}

static inline void pwqr_sb_get(struct pwqr_sb *sb)
{
        kref_get(&sb->kref);
}

static void pwqr_sb_finalize(struct rcu_head *rcu)
{
        struct pwqr_sb *sb = container_of(rcu, struct pwqr_sb, rcu);

        module_put(THIS_MODULE);
        kfree(sb);
}

static void pwqr_sb_release(struct kref *kref)
{
        struct pwqr_sb *sb = container_of(kref, struct pwqr_sb, kref);

        del_timer_sync(&sb->timer);
        call_rcu(&sb->rcu, pwqr_sb_finalize);
}

static inline void pwqr_sb_put(struct pwqr_sb *sb)
{
        kref_put(&sb->kref, pwqr_sb_release);
}

/*****************************************************************************
 * tasks
 */
static inline struct pwqr_task_bucket *task_hbucket(struct task_struct *task)
{
        return &pwqr_tasks_hash[hash_ptr(task, PWQR_HASH_BITS)];
}

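/*
 * Look up the pwqr_task attached to @task in the global hash; returns NULL
 * if the task was never registered.
 */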
static struct pwqr_task *pwqr_task_find(struct task_struct *task)
{
        struct pwqr_task_bucket *b = task_hbucket(task);
        struct hlist_node *node;
        struct pwqr_task *pwqt = NULL;

        spin_lock(&b->lock);
        hlist_for_each_entry(pwqt, node, &b->tasks, link) {
                if (pwqt->task == task)
                        goto found;
        }
        /* the 4-argument hlist_for_each_entry() leaves the cursor at the
         * last examined entry, not NULL, when nothing matched: reset it */
        pwqt = NULL;
found:
        spin_unlock(&b->lock);
        return pwqt;
}

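/*
 * Allocate a pwqr_task for the calling task, register its preempt notifier
 * in the "running" state and hash it; the caller is expected to attach it
 * to a scoreboard with pwqr_task_attach().
 */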
static struct pwqr_task *pwqr_task_create(struct task_struct *task)
{
        struct pwqr_task_bucket *b = task_hbucket(task);
        struct pwqr_task *pwqt;

        pwqt = kmalloc(sizeof(*pwqt), GFP_KERNEL);
        if (pwqt == NULL)
                return ERR_PTR(-ENOMEM);

        preempt_notifier_init(&pwqt->notifier, &pwqr_preempt_running_ops);
        preempt_notifier_register(&pwqt->notifier);
        pwqt->task = task;

        spin_lock(&b->lock);
        hlist_add_head(&pwqt->link, &b->tasks);
        spin_unlock(&b->lock);

        return pwqt;
}

__cold
static void pwqr_task_detach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->registered--;
        if (pwqt->notifier.ops == &pwqr_preempt_running_ops) {
                __pwqr_sb_update_state(sb, -1);
        } else {
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);
        pwqr_sb_put(sb);
        pwqt->sb = NULL;
}

__cold
static void pwqr_task_attach(struct pwqr_task *pwqt, struct pwqr_sb *sb)
{
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        pwqr_sb_get(pwqt->sb = sb);
        sb->registered++;
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}

__cold
static void pwqr_task_release(struct pwqr_task *pwqt, bool from_notifier)
{
        struct pwqr_task_bucket *b = task_hbucket(pwqt->task);

        spin_lock(&b->lock);
        hlist_del(&pwqt->link);
        spin_unlock(&b->lock);
        pwqt->notifier.ops = &pwqr_preempt_noop_ops;

        if (from_notifier) {
                /* When called from sched_{out,in}, we may not call
                 * preempt_notifier_unregister() (or worse, kfree()).
                 *
                 * kfree()ing a still-registered notifier is only safe
                 * because the task is dying and will never be scheduled
                 * again; otherwise the next sched_{in,out} call would
                 * touch freed memory, hence the BUG_ON below.
                 */
                BUG_ON(!(pwqt->task->state & TASK_DEAD));
                kfree_rcu(pwqt, rcu);
        } else {
                preempt_notifier_unregister(&pwqt->notifier);
                kfree(pwqt);
        }
}

static void pwqr_task_noop_sched_in(struct preempt_notifier *notifier, int cpu)
{
}

static void pwqr_task_noop_sched_out(struct preempt_notifier *notifier,
                                     struct task_struct *next)
{
}

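/*
 * A registered worker flips between two notifier states: "running" ops
 * while on CPU, "blocked" ops once it schedules out for a real block
 * (not a mere preemption). The noop ops are only installed on release,
 * to keep late notifier calls harmless.
 */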
static void pwqr_task_blocked_sched_in(struct preempt_notifier *notifier, int cpu)
{
        struct pwqr_task *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb   *sb   = pwqt->sb;
        unsigned long flags;

        if (unlikely(sb->state < 0)) {
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }

        pwqt->notifier.ops = &pwqr_preempt_running_ops;
        pwqr_sb_lock_irqsave(sb, flags);
        __pwqr_sb_update_state(sb, 1);
        pwqr_sb_unlock_irqrestore(sb, flags);
}

static void pwqr_task_sched_out(struct preempt_notifier *notifier,
                                struct task_struct *next)
{
        struct pwqr_task   *pwqt = container_of(notifier, struct pwqr_task, notifier);
        struct pwqr_sb     *sb   = pwqt->sb;
        struct task_struct *p    = pwqt->task;

        if (unlikely(p->state & TASK_DEAD) || unlikely(sb->state < 0)) {
                pwqr_task_detach(pwqt, sb);
                pwqr_task_release(pwqt, true);
                return;
        }
        /* state == 0 (TASK_RUNNING) is a mere preemption, and stopped or
         * traced tasks will come back: neither counts as blocking */
        if (p->state == 0 || (p->state & (__TASK_STOPPED | __TASK_TRACED)))
                return;

        pwqt->notifier.ops = &pwqr_preempt_blocked_ops;
        /* see preempt.h: IRQs are already disabled in sched_out */
        spin_lock(&sb->wqh.lock);
        __pwqr_sb_update_state(sb, -1);
        spin_unlock(&sb->wqh.lock);
}

static struct preempt_ops __read_mostly pwqr_preempt_noop_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_noop_sched_out,
};

static struct preempt_ops __read_mostly pwqr_preempt_running_ops = {
        .sched_in       = pwqr_task_noop_sched_in,
        .sched_out      = pwqr_task_sched_out,
};

static struct preempt_ops __read_mostly pwqr_preempt_blocked_ops = {
        .sched_in       = pwqr_task_blocked_sched_in,
        .sched_out      = pwqr_task_sched_out,
};

/*****************************************************************************
 * file descriptor
 */
static int pwqr_open(struct inode *inode, struct file *filp)
{
        struct pwqr_sb *sb;

        sb = pwqr_sb_create();
        if (IS_ERR(sb))
                return PTR_ERR(sb);
        filp->private_data = sb;
        return 0;
}

static int pwqr_release(struct inode *inode, struct file *filp)
{
        struct pwqr_sb *sb = filp->private_data;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        sb->state = PWQR_STATE_DEAD;
        pwqr_sb_unlock_irqrestore(sb, flags);
        wake_up_all(&sb->wqh);
        pwqr_sb_put(sb);
        return 0;
}

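/*
 * Common slow path for PWQR_WAIT (@is_wait != 0) and PWQR_PARK. The
 * caller's preempt notifier is unregistered for the duration of the call
 * so that this very sleep is not accounted as a blocked worker. Waiters
 * queue at the head and parked threads at the tail, so wake-ups reach
 * waiting threads before parked ones.
 */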
static long
do_pwqr_wait(struct pwqr_sb *sb, struct pwqr_task *pwqt,
             int is_wait, struct pwqr_ioc_wait __user *arg)
{
        unsigned long flags;
        struct pwqr_ioc_wait wait;
        long rc = 0;
        u32 uval;

        preempt_notifier_unregister(&pwqt->notifier);

        if (is_wait) {
                if (copy_from_user(&wait, arg, sizeof(wait))) {
                        rc = -EFAULT;
                        goto out;
                }
                if (unlikely((long)wait.pwqr_uaddr % sizeof(int) != 0)) {
                        rc = -EINVAL;
                        goto out;
                }
        }

        pwqr_sb_lock_irqsave(sb, flags);
        if (sb->running + sb->waiting <= sb->concurrency) {
                if (is_wait) {
                        /* fault-in loop: the user word may not be mapped;
                         * fault it in with get_user() outside the lock until
                         * a locked, nonfaulting read succeeds */
                        while (probe_kernel_address(wait.pwqr_uaddr, uval)) {
                                pwqr_sb_unlock_irqrestore(sb, flags);
                                rc = get_user(uval, (u32 __user *)wait.pwqr_uaddr);
                                if (rc)
                                        goto out;
                                pwqr_sb_lock_irqsave(sb, flags);
                        }

                        if (uval != (u32)wait.pwqr_ticket) {
                                rc = -EWOULDBLOCK;
                                goto out_unlock;
                        }
                } else {
                        goto out_unlock;
                }
        }

        /* see wait_event_interruptible_exclusive_locked_irq() */
        if (likely(sb->state >= 0)) {
                DEFINE_WAIT(__wait);

                __wait.flags |= WQ_FLAG_EXCLUSIVE;

                if (is_wait) {
                        sb->waiting++;
                        __add_wait_queue(&sb->wqh, &__wait);
                } else {
                        sb->parked++;
                        __add_wait_queue_tail(&sb->wqh, &__wait);
                }
                __pwqr_sb_update_state(sb, -1);
                set_current_state(TASK_INTERRUPTIBLE);

                do {
                        if (sb->overcommit_wakes)
                                break;
                        if (signal_pending(current)) {
                                rc = -ERESTARTSYS;
                                break;
                        }
                        spin_unlock_irq(&sb->wqh.lock);
                        schedule();
                        spin_lock_irq(&sb->wqh.lock);
                        if (is_wait)
                                break;
                        if (sb->running + sb->waiting < sb->concurrency)
                                break;
                } while (likely(sb->state >= 0));

                __remove_wait_queue(&sb->wqh, &__wait);
                __set_current_state(TASK_RUNNING);

                if (is_wait) {
                        sb->waiting--;
                } else {
                        sb->parked--;
                }
                __pwqr_sb_update_state(sb, 1);
                if (sb->overcommit_wakes)
                        sb->overcommit_wakes--;
                /* still overcommitted: tell userland this thread should park */
                if (sb->waiting + sb->running > sb->concurrency)
                        rc = -EDQUOT;
        }

out_unlock:
        if (unlikely(sb->state < 0))
                rc = -EBADFD;
        pwqr_sb_unlock_irqrestore(sb, flags);
out:
        preempt_notifier_register(&pwqt->notifier);
        return rc;
}

static long do_pwqr_unregister(struct pwqr_sb *sb, struct pwqr_task *pwqt)
{
        if (!pwqt)
                return -EINVAL;
        if (pwqt->sb != sb)
                return -ENOENT;
        pwqr_task_detach(pwqt, sb);
        pwqr_task_release(pwqt, false);
        return 0;
}

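/*
 * Update the concurrency target (<= 0 means "one per online CPU") and
 * re-evaluate the scoreboard; returns the previous target.
 */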
static long do_pwqr_set_conc(struct pwqr_sb *sb, int conc)
{
        long old_conc;
        unsigned long flags;

        pwqr_sb_lock_irqsave(sb, flags);
        /* read the old value under the lock so concurrent setters serialize */
        old_conc = sb->concurrency;
        if (conc <= 0)
                conc = num_online_cpus();
        if (conc != old_conc) {
                sb->concurrency = conc;
                __pwqr_sb_update_state(sb, 0);
        }
        pwqr_sb_unlock_irqrestore(sb, flags);

        return old_conc;
}

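/*
 * Wake up to @count threads. For PWQR_WAKE_OC (@oc != 0) the wakes are
 * forced and accounted in overcommit_wakes; for PWQR_WAKE they are clamped
 * so that running threads cannot exceed the concurrency target. Returns
 * the number of threads userland should consider unblocked.
 */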
static long do_pwqr_wake(struct pwqr_sb *sb, int oc, int count)
{
        unsigned long flags;
        int nwake;

        if (count < 0)
                return -EINVAL;

        pwqr_sb_lock_irqsave(sb, flags);

        if (oc) {
                nwake = sb->waiting + sb->parked - sb->overcommit_wakes;
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
                sb->overcommit_wakes += count;
        } else if (sb->running + sb->overcommit_wakes < sb->concurrency) {
                nwake = sb->concurrency - sb->overcommit_wakes - sb->running;
                if (nwake > sb->waiting + sb->parked - sb->overcommit_wakes) {
                        nwake = sb->waiting + sb->parked -
                                sb->overcommit_wakes;
                }
                if (count > nwake) {
                        count = nwake;
                } else {
                        nwake = count;
                }
        } else {
                /*
                 * This codepath deserves an explanation: waking a thread
                 * "for real" here would overcommit, even though userspace
                 * KNOWS there is at least one waiting thread. Such threads
                 * are "quarantined".
                 *
                 * Quarantined threads are woken up one by one, to allow a
                 * slow ramp down and to minimize "waiting" <-> "parked"
                 * flip-flops, no matter how many wakes have been asked for.
                 *
                 * Since releasing one quarantined thread wakes up a thread
                 * that will almost immediately go back to parked mode, we
                 * lie to userland about having unblocked anything, and
                 * return 0.
                 *
                 * If we're already waking all waiting threads for
                 * overcommitting jobs, though, even that is not needed.
                 */
                count = 0;
                nwake = sb->waiting > sb->overcommit_wakes;
        }
        while (nwake-- > 0)
                wake_up_locked(&sb->wqh);
        pwqr_sb_unlock_irqrestore(sb, flags);

        return count;
}


static long pwqr_ioctl(struct file *filp, unsigned command, unsigned long arg)
{
        struct pwqr_sb     *sb   = filp->private_data;
        struct task_struct *task = current;
        struct pwqr_task   *pwqt;
        int rc = 0;

        switch (command) {
        case PWQR_GET_CONC:
                return sb->concurrency;
        case PWQR_SET_CONC:
                return do_pwqr_set_conc(sb, (int)arg);

        case PWQR_WAKE:
        case PWQR_WAKE_OC:
                return do_pwqr_wake(sb, command == PWQR_WAKE_OC, (int)arg);

        case PWQR_WAIT:
        case PWQR_PARK:
        case PWQR_REGISTER:
        case PWQR_UNREGISTER:
                break;
        default:
                return -EINVAL;
        }

        pwqt = pwqr_task_find(task);
        if (command == PWQR_UNREGISTER)
                return do_pwqr_unregister(sb, pwqt);

        if (pwqt == NULL) {
                pwqt = pwqr_task_create(task);
                if (IS_ERR(pwqt))
                        return PTR_ERR(pwqt);
                pwqr_task_attach(pwqt, sb);
        } else if (unlikely(pwqt->sb != sb)) {
                pwqr_task_detach(pwqt, pwqt->sb);
                pwqr_task_attach(pwqt, sb);
        }

        switch (command) {
        case PWQR_WAIT:
                rc = do_pwqr_wait(sb, pwqt, true, (struct pwqr_ioc_wait __user *)arg);
                break;
        case PWQR_PARK:
                rc = do_pwqr_wait(sb, pwqt, false, NULL);
                break;
        }

        if (unlikely(sb->state < 0)) {
                pwqr_task_detach(pwqt, pwqt->sb);
                return -EBADFD;
        }
        return rc;
}

static const struct file_operations pwqr_dev_fops = {
        .owner          = THIS_MODULE,
        .open           = pwqr_open,
        .release        = pwqr_release,
        .unlocked_ioctl = pwqr_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = pwqr_ioctl,
#endif
};

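/*
 * Illustrative userland usage (a sketch, not part of this module: it
 * assumes the device node shows up as /dev/pwqr and relies on the PWQR_*
 * ioctls and struct pwqr_ioc_wait exactly as declared in pwqr.h):
 *
 *      int fd = open("/dev/pwqr", O_RDWR);
 *
 *      ioctl(fd, PWQR_SET_CONC, 0);    // <= 0 means num_online_cpus()
 *
 *      // worker thread: sleep until the ticket word at uaddr changes
 *      struct pwqr_ioc_wait w;
 *      w.pwqr_ticket = ticket;         // value last read from the word
 *      w.pwqr_uaddr  = uaddr;          // int-aligned userspace address
 *      if (ioctl(fd, PWQR_WAIT, &w) < 0 && errno == EDQUOT)
 *              ioctl(fd, PWQR_PARK, 0);        // overcommitted: park
 *
 *      ioctl(fd, PWQR_WAKE, 1);        // make one worker runnable, if allowed
 */
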
/*****************************************************************************
 * module
 */
static int __init pwqr_start(void)
{
        int i;

        for (i = 0; i < PWQR_HASH_SIZE; i++) {
                spin_lock_init(&pwqr_tasks_hash[i].lock);
                INIT_HLIST_HEAD(&pwqr_tasks_hash[i].tasks);
        }

        /* Register as a character device, under the same name we
         * unregister with */
        pwqr_major = register_chrdev(0, PWQR_DEVICE_NAME, &pwqr_dev_fops);
        if (pwqr_major < 0) {
                printk(KERN_ERR "pwqr: register_chrdev() failed\n");
                return pwqr_major;
        }

        /* Create a device node */
        pwqr_class = class_create(THIS_MODULE, PWQR_DEVICE_NAME);
        if (IS_ERR(pwqr_class)) {
                printk(KERN_ERR "pwqr: Error creating raw class\n");
                unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
                return PTR_ERR(pwqr_class);
        }
        device_create(pwqr_class, NULL, MKDEV(pwqr_major, 0), NULL, PWQR_DEVICE_NAME);
        printk(KERN_INFO "pwqr: PThreads Work Queues Regulator v1 loaded\n");
        return 0;
}

static void __exit pwqr_end(void)
{
        rcu_barrier();
        device_destroy(pwqr_class, MKDEV(pwqr_major, 0));
        class_destroy(pwqr_class);
        unregister_chrdev(pwqr_major, PWQR_DEVICE_NAME);
}

module_init(pwqr_start);
module_exit(pwqr_end);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Pierre Habouzit <pierre.habouzit@intersec.com>");
MODULE_DESCRIPTION("PThreads Work Queues Regulator");

// vim:noet:sw=8:cinoptions+=\:0,L-1,=1s: