#define CREATE_TRACE_POINTS
#include <trace/events/mmap_lock.h>
#include <linux/mm.h>
#include <linux/cgroup.h>
#include <linux/memcontrol.h>
#include <linux/mmap_lock.h>
#include <linux/mutex.h>
#include <linux/percpu.h>
#include <linux/rcupdate.h>
#include <linux/smp.h>
#include <linux/trace_events.h>
#include <linux/local_lock.h>
EXPORT_TRACEPOINT_SYMBOL(mmap_lock_start_locking);
EXPORT_TRACEPOINT_SYMBOL(mmap_lock_acquire_returned);
EXPORT_TRACEPOINT_SYMBOL(mmap_lock_released);
#ifdef CONFIG_TRACING
/*
 * Out-of-line wrapper emitting the mmap_lock_start_locking trace event.
 * Kept out of line (and exported) so mmap_lock.h inline helpers don't
 * pull the tracepoint machinery into every caller.
 */
void __mmap_lock_do_trace_start_locking(struct mm_struct *mm, bool write)
{
	trace_mmap_lock_start_locking(mm, write);
}
EXPORT_SYMBOL(__mmap_lock_do_trace_start_locking);
/*
 * Out-of-line wrapper emitting the mmap_lock_acquire_returned trace
 * event; @success reports whether the lock attempt succeeded.
 */
void __mmap_lock_do_trace_acquire_returned(struct mm_struct *mm, bool write,
					   bool success)
{
	trace_mmap_lock_acquire_returned(mm, write, success);
}
EXPORT_SYMBOL(__mmap_lock_do_trace_acquire_returned);
/*
 * Out-of-line wrapper emitting the mmap_lock_released trace event.
 */
void __mmap_lock_do_trace_released(struct mm_struct *mm, bool write)
{
	trace_mmap_lock_released(mm, write);
}
EXPORT_SYMBOL(__mmap_lock_do_trace_released);
#endif
#ifdef CONFIG_MMU
#ifdef CONFIG_PER_VMA_LOCK
/*
 * State for temporarily excluding readers of a VMA while the caller
 * holds the mmap write lock; filled in by __vma_start_exclude_readers()
 * and consumed by __vma_end_exclude_readers().
 */
struct vma_exclude_readers_state {
	struct vm_area_struct *vma;	/* VMA whose readers are being excluded */
	int state;			/* task state used while waiting for readers */
	bool detaching;			/* VMA is being detached: expect no attach ref */
	bool detached;			/* set when vm_refcnt is (or drops to) zero */
	bool exclusive;			/* readers successfully excluded; must be ended */
};
/*
 * Drop the VM_REFCNT_EXCLUDE_READERS_FLAG bias added by
 * __vma_start_exclude_readers().  Records in ves->detached whether the
 * refcount hit zero as a result (i.e. the VMA is now fully detached).
 */
static void __vma_end_exclude_readers(struct vma_exclude_readers_state *ves)
{
	struct vm_area_struct *vma = ves->vma;

	/* The VMA must not already have been observed as detached. */
	VM_WARN_ON_ONCE(ves->detached);
	ves->detached = refcount_sub_and_test(VM_REFCNT_EXCLUDE_READERS_FLAG,
					      &vma->vm_refcnt);
	__vma_lockdep_release_exclusive(vma);
}
/*
 * Compute the vm_refcnt value that signals all readers are gone:
 * the exclude-readers bias, plus the single attach reference unless
 * the VMA is being detached (in which case no attach ref remains).
 */
static unsigned int get_target_refcnt(struct vma_exclude_readers_state *ves)
{
	unsigned int expected = VM_REFCNT_EXCLUDE_READERS_FLAG;

	if (!ves->detaching)
		expected |= 1;

	return expected;
}
/*
 * Add the exclude-readers bias to vma->vm_refcnt and wait, in task
 * state ves->state, until all existing readers drop their references.
 * Caller must hold the mmap write lock.
 *
 * On return:
 *   ves->detached  - refcount was already zero (VMA detached); no bias
 *                    was added and no waiting occurred
 *   ves->exclusive - readers fully excluded; the caller must later call
 *                    __vma_end_exclude_readers()
 *
 * Returns 0 on success, or a nonzero error from the wait (the bias is
 * dropped again before the error is returned).
 */
static int __vma_start_exclude_readers(struct vma_exclude_readers_state *ves)
{
	struct vm_area_struct *vma = ves->vma;
	unsigned int tgt_refcnt = get_target_refcnt(ves);
	int err = 0;

	mmap_assert_write_locked(vma->vm_mm);

	/* A zero refcount means the VMA is already detached: nothing to do. */
	if (!refcount_add_not_zero(VM_REFCNT_EXCLUDE_READERS_FLAG, &vma->vm_refcnt)) {
		ves->detached = true;
		return 0;
	}
	__vma_lockdep_acquire_exclusive(vma);

	/* Wait until only the bias (and attach ref, if any) remains. */
	err = rcuwait_wait_event(&vma->vm_mm->vma_writer_wait,
				 refcount_read(&vma->vm_refcnt) == tgt_refcnt,
				 ves->state);
	if (err) {
		/* Interrupted: undo the bias before reporting the error. */
		__vma_end_exclude_readers(ves);
		return err;
	}
	__vma_lockdep_stat_mark_acquired(vma);
	ves->exclusive = true;
	return 0;
}
/*
 * Write-lock @vma: exclude all current readers, then publish the mm's
 * write-side sequence number in vma->vm_lock_seq so that subsequent
 * vma_start_read() attempts bail out early.  Caller must hold the mmap
 * write lock.  @state is the task state used while waiting for readers.
 * Returns 0 on success or the error from the interruptible wait.
 */
int __vma_start_write(struct vm_area_struct *vma, int state)
{
	/* Snapshot the mm seqcount before readers are excluded. */
	const unsigned int mm_lock_seq = __vma_raw_mm_seqnum(vma);
	struct vma_exclude_readers_state ves = {
		.vma = vma,
		.state = state,
	};
	int err;

	err = __vma_start_exclude_readers(&ves);
	if (err) {
		/* Wait was interrupted; the VMA must still be attached. */
		WARN_ON_ONCE(ves.detached);
		return err;
	}

	/* Pairs with the READ_ONCE() in vma_start_read()'s early check. */
	WRITE_ONCE(vma->vm_lock_seq, mm_lock_seq);

	/* Drop the reader-exclusion bias; the seqcount now holds readers off. */
	if (ves.exclusive) {
		__vma_end_exclude_readers(&ves);
		/* An attached VMA must not reach a zero refcount here. */
		WARN_ON_ONCE(ves.detached);
	}
	return 0;
}
EXPORT_SYMBOL_GPL(__vma_start_write);
/*
 * Wait (uninterruptibly) for all readers of @vma to go away as part of
 * detaching it.  On return ves.detached must be set, i.e. vm_refcnt has
 * reached zero and the VMA is fully detached.
 */
void __vma_exclude_readers_for_detach(struct vm_area_struct *vma)
{
	struct vma_exclude_readers_state ves = {
		.vma = vma,
		.state = TASK_UNINTERRUPTIBLE,
		.detaching = true,	/* target refcount excludes the attach ref */
	};
	int err;

	err = __vma_start_exclude_readers(&ves);
	/*
	 * The TASK_UNINTERRUPTIBLE wait cannot return an error; dropping
	 * the bias must then leave the refcount at zero (detached) since
	 * no attach reference remains while detaching.
	 */
	if (!err && ves.exclusive) {
		__vma_end_exclude_readers(&ves);
	}
	WARN_ON_ONCE(!ves.detached);
}
/*
 * Try to read-lock @vma via its reference count, without the mmap lock.
 * Must be called with the RCU read lock held.  On success the locked
 * VMA is returned and the RCU read lock is LEFT HELD; on any failure
 * (NULL or ERR_PTR return) the RCU read lock is RELEASED.
 *
 * Returns:
 *   vma              - read lock acquired
 *   NULL             - write-locked, refcount limit hit, or mm mismatch;
 *                      caller should fall back to the mmap lock
 *   ERR_PTR(-EAGAIN) - VMA was detached (refcount zero); retry lookup
 */
static inline struct vm_area_struct *vma_start_read(struct mm_struct *mm,
						    struct vm_area_struct *vma)
{
	struct mm_struct *other_mm;
	int oldcnt;

	RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "no rcu lock held");

	/* Early lockless check: a writer currently holds this VMA. */
	if (READ_ONCE(vma->vm_lock_seq) == READ_ONCE(mm->mm_lock_seq.sequence)) {
		vma = NULL;
		goto err;
	}

	/*
	 * Take a reader reference unless the count is zero (detached VMA)
	 * or at the limit; oldcnt distinguishes the two failure modes.
	 */
	if (unlikely(!__refcount_inc_not_zero_limited_acquire(&vma->vm_refcnt, &oldcnt,
							      VM_REFCNT_LIMIT))) {
		vma = oldcnt ? NULL : ERR_PTR(-EAGAIN);
		goto err;
	}
	__vma_lockdep_acquire_read(vma);

	/* The VMA may have been freed and reused under a different mm. */
	if (unlikely(vma->vm_mm != mm))
		goto err_unstable;

	/* Recheck the write-lock seqcount now that we hold a reference. */
	if (unlikely(vma->vm_lock_seq == raw_read_seqcount(&mm->mm_lock_seq))) {
		vma_refcount_put(vma);
		vma = NULL;
		goto err;
	}

	return vma;
err:
	rcu_read_unlock();
	return vma;
err_unstable:
	/*
	 * The VMA got attached to another mm from under us.  Grab that mm
	 * before dropping vm_refcnt: releasing the reference may need to
	 * reach the foreign mm (writer wakeup), which must stay alive
	 * across vma_refcount_put().
	 */
	other_mm = vma->vm_mm;
	rcu_read_unlock();
	mmgrab(other_mm);
	vma_refcount_put(vma);
	mmdrop(other_mm);
	return NULL;
}
/*
 * Look up and read-lock the VMA covering @address without taking the
 * mmap lock.  Returns the locked VMA (release with vma_end_read()) or
 * NULL on failure.  The RCU read lock is taken and released internally;
 * on success the returned VMA is kept alive by its refcount alone.
 */
struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
					  unsigned long address)
{
	MA_STATE(mas, &mm->mm_mt, address, address);
	struct vm_area_struct *vma;

retry:
	rcu_read_lock();
	vma = mas_walk(&mas);
	if (!vma) {
		rcu_read_unlock();
		goto inval;
	}

	/* On failure vma_start_read() has already dropped the RCU lock. */
	vma = vma_start_read(mm, vma);
	if (IS_ERR_OR_NULL(vma)) {
		if (PTR_ERR(vma) == -EAGAIN) {
			/* VMA was detached from under us; redo the walk. */
			count_vm_vma_lock_event(VMA_LOCK_MISS);
			mas_set(&mas, address);
			goto retry;
		}
		goto inval;
	}
	rcu_read_unlock();

	/* The locked VMA may have been moved; verify it still covers @address. */
	if (unlikely(address < vma->vm_start || address >= vma->vm_end)) {
		vma_end_read(vma);
		goto inval;
	}

	return vma;
inval:
	count_vm_vma_lock_event(VMA_LOCK_ABORT);
	return NULL;
}
/*
 * Fallback path for lock_next_vma(): find the next VMA at or after
 * @from_addr under the mmap read lock and take its per-VMA read lock
 * before dropping the mmap lock again.  Returns the locked VMA, NULL
 * if there is none, or an ERR_PTR (error from mmap_read_lock_killable(),
 * or -EAGAIN if the per-VMA lock could not be acquired).
 */
static struct vm_area_struct *lock_next_vma_under_mmap_lock(struct mm_struct *mm,
							    struct vma_iterator *vmi,
							    unsigned long from_addr)
{
	struct vm_area_struct *next;
	int err;

	err = mmap_read_lock_killable(mm);
	if (err)
		return ERR_PTR(err);

	vma_iter_set(vmi, from_addr);
	next = vma_next(vmi);
	if (next && !vma_start_read_locked(next))
		next = ERR_PTR(-EAGAIN);

	mmap_read_unlock(mm);
	return next;
}
/*
 * Iterator helper: read-lock and return the next VMA at or after
 * @from_addr, preferring the lockless per-VMA path and falling back to
 * the mmap lock when that races.  Must be called with the RCU read
 * lock held; the RCU read lock is held again on every return path.
 * Returns the locked VMA, NULL when iteration is done, or an ERR_PTR
 * from the fallback path.
 */
struct vm_area_struct *lock_next_vma(struct mm_struct *mm,
				     struct vma_iterator *vmi,
				     unsigned long from_addr)
{
	struct vm_area_struct *vma;
	unsigned int mm_wr_seq;
	bool mmap_unlocked;

	RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "no rcu read lock held");
retry:
	/* Begin mmap_lock speculation in case the vma needs verifying later. */
	mmap_unlocked = mmap_lock_speculate_try_begin(mm, &mm_wr_seq);
	vma = vma_next(vmi);
	if (!vma)
		return NULL;

	/* On failure vma_start_read() has already dropped the RCU lock. */
	vma = vma_start_read(mm, vma);
	if (IS_ERR_OR_NULL(vma)) {
		if (PTR_ERR(vma) == -EAGAIN) {
			/* VMA was detached; re-acquire RCU and redo the walk. */
			rcu_read_lock();
			vma_iter_set(vmi, from_addr);
			goto retry;
		}
		goto fallback;
	}

	/* The locked VMA must not lie behind the last search position. */
	if (unlikely(from_addr >= vma->vm_end))
		goto fallback_unlock;

	/*
	 * A gap before the VMA is only trustworthy if no writer ran in
	 * the meantime; otherwise re-walk to confirm this is really the
	 * next VMA after @from_addr.
	 */
	if (from_addr < vma->vm_start) {
		if (!mmap_unlocked || mmap_lock_speculate_retry(mm, mm_wr_seq)) {
			vma_iter_set(vmi, from_addr);
			if (vma != vma_next(vmi))
				goto fallback_unlock;
		}
	}

	return vma;

fallback_unlock:
	rcu_read_unlock();
	vma_end_read(vma);
fallback:
	vma = lock_next_vma_under_mmap_lock(mm, vmi, from_addr);
	rcu_read_lock();
	/* Reposition the iterator for the next call (it was used unlocked). */
	vma_iter_set(vmi, IS_ERR_OR_NULL(vma) ? from_addr : vma->vm_end);
	return vma;
}
#endif
#ifdef CONFIG_LOCK_MM_AND_FIND_VMA
#include <linux/extable.h>
/*
 * Take the mmap read lock on a page-fault path.  The common case is an
 * uncontended trylock.  Otherwise, for a fault from kernel mode, only
 * sleep on the lock if the faulting instruction has an exception-table
 * fixup — a stray kernel access must not block on mmap_lock.  Returns
 * true with the lock held, false without it.
 */
static inline bool get_mmap_lock_carefully(struct mm_struct *mm, struct pt_regs *regs)
{
	bool kernel_fault;

	if (likely(mmap_read_trylock(mm)))
		return true;

	kernel_fault = regs && !user_mode(regs);
	if (kernel_fault && !search_exception_tables(exception_ip(regs)))
		return false;

	return !mmap_read_lock_killable(mm);
}
/*
 * Placeholder: no direct read-to-write mmap_lock upgrade is available,
 * so always fail and make callers drop and re-take the lock instead.
 */
static inline bool mmap_upgrade_trylock(struct mm_struct *mm)
{
	return false;
}
/*
 * "Upgrade" the mmap lock from read to write by dropping it and taking
 * the write lock.  As in get_mmap_lock_carefully(), a kernel-mode fault
 * without an exception-table fixup must not sleep on the lock.  Returns
 * true with the write lock held, false with no lock held.
 */
static inline bool upgrade_mmap_lock_carefully(struct mm_struct *mm, struct pt_regs *regs)
{
	bool kernel_fault;

	mmap_read_unlock(mm);

	kernel_fault = regs && !user_mode(regs);
	if (kernel_fault && !search_exception_tables(exception_ip(regs)))
		return false;

	return !mmap_write_lock_killable(mm);
}
/*
 * Page-fault helper: take the mmap lock and find the VMA covering
 * @addr, expanding a VM_GROWSDOWN (stack) VMA downwards if necessary.
 * Returns the VMA with the mmap READ lock held, or NULL with no lock
 * held.
 */
struct vm_area_struct *lock_mm_and_find_vma(struct mm_struct *mm,
			unsigned long addr, struct pt_regs *regs)
{
	struct vm_area_struct *vma;

	if (!get_mmap_lock_carefully(mm, regs))
		return NULL;

	vma = find_vma(mm, addr);
	if (likely(vma && (vma->vm_start <= addr)))
		return vma;

	/* Only a stack VMA just above @addr can be expanded to cover it. */
	if (!vma || !(vma->vm_flags & VM_GROWSDOWN)) {
		mmap_read_unlock(mm);
		return NULL;
	}

	/* Stack expansion needs the write lock. */
	if (!mmap_upgrade_trylock(mm)) {
		if (!upgrade_mmap_lock_carefully(mm, regs))
			return NULL;
		/* The lock was dropped during the upgrade; re-validate. */
		vma = find_vma(mm, addr);
		if (!vma)
			goto fail;
		if (vma->vm_start <= addr)
			goto success;
		if (!(vma->vm_flags & VM_GROWSDOWN))
			goto fail;
	}

	if (expand_stack_locked(vma, addr))
		goto fail;

success:
	mmap_write_downgrade(mm);
	return vma;

fail:
	mmap_write_unlock(mm);
	return NULL;
}
#endif
#else
/*
 * !CONFIG_MMU variant: no stack expansion to worry about, so simply
 * take the mmap read lock and look the address up.  Returns the VMA
 * with the read lock held, or NULL with the lock released.
 */
struct vm_area_struct *lock_mm_and_find_vma(struct mm_struct *mm,
			unsigned long addr, struct pt_regs *regs)
{
	struct vm_area_struct *found;

	mmap_read_lock(mm);
	found = vma_lookup(mm, addr);
	if (!found)
		mmap_read_unlock(mm);

	return found;
}
#endif