/**
@file     exited_tasks.c
@brief    Cache exited tasks
@details  Copyright (c) 2025 Acronis International GmbH
@author   Bruce Wang (bruce.wang@acronis.com)
@since    $Id: $
*/

#include "debug.h"
#include "exec_event.h"
#include "exited_tasks.h"
#include "ftrace_hooks/ftrace_events.h"
#include "ftrace_hooks/reg_tools.h"
#include "memory.h"
#include "syscall_common.h"

#include <linux/atomic.h>
#include <linux/fs.h>
#include <linux/namei.h>
#include <linux/spinlock.h>

// Nominal capacity used by the eviction check in exited_tasks_add().
#define EXITED_TASKS_LRU_SIZE 2048
// Eviction in exited_tasks_add() keeps going while the free room
// (EXITED_TASKS_LRU_SIZE - current size) is not greater than this value.
#define EXITED_TASKS_LRU_CLEAN_SIZE 100
// Age, in milliseconds, after which the oldest cached node is evicted on the next add.
#define EXITED_TASKS_EXPIRE_TIME_MS (30 * 1000) // 30s

// Fallback for environments whose list header does not provide this helper:
// yields the first entry of the list, or NULL when the list is empty.
#ifndef list_first_entry_or_null
#define list_first_entry_or_null(ptr, type, member) (list_empty(ptr) ? NULL : list_first_entry(ptr, type, member))
#endif

// Allocation cache for exited_task_node objects (project KMEM_* macros).
static KMEM_STRUCT_CACHE_DECLARE(exited_task_node);

// Taken around every access to exited_tasks_list in this file.
static spinlock_t exited_tasks_list_lock;
// Serializes exited_tasks_start() and exited_tasks_stop().
static struct mutex exited_tasks_list_startstop_mutex;
// True while the cache accepts new nodes; written under the start/stop mutex,
// read by exited_tasks_add() both with READ_ONCE and under the list spinlock.
static bool exited_tasks_listening = false;
// Cached nodes, oldest first (new nodes are appended at the tail).
static struct list_head exited_tasks_list;
// Number of nodes currently linked into exited_tasks_list.
static uint32_t exited_tasks_list_size = 0;

// Forward declaration: exited_tasks_add() releases nodes before the definition below.
void exited_task_put(exited_task_node_t *node);

static void exited_tasks_add(struct task_struct *task)
{
    int ret;
    exited_task_node_t *node;
    uint64_t unique_pid = make_unique_pid(task);

    if (!READ_ONCE(exited_tasks_listening))
    {
        DPRINTF("Exited tasks table is not enabled");
        return;
    }

    node = KMEM_NEW0(exited_task_node);
    if (!node)
    {
        EPRINTF("Failed to allocate memory for exited task node");
        return;
    }

    ret = process_info_return_msg_new(&node->msg_varsized, 0, task);
    if (ret != 0)
    {
        EPRINTF("Failed to create msg_varsized for exited task PID %d", task->pid);
        KMEM_DELETE(exited_task_node, node);
        return;
    }

    node->unique_pid = unique_pid;
    node->pid = task->pid;
    node->exit_time = jiffies;
    atomic_set(&node->refcount, 1);

    {
        struct list_head tmp_list;
        INIT_LIST_HEAD(&tmp_list);

        spin_lock(&exited_tasks_list_lock);
        if (exited_tasks_listening)
        {
            list_add_tail(&node->list_node, &exited_tasks_list);
            exited_tasks_list_size++;
            while (1)
            {
                node = list_first_entry_or_null(&exited_tasks_list, exited_task_node_t, list_node);
                if (!node)
                {
                    break;
                }
                // stop when EXITED_TASKS_LRU_CLEAN_SIZE of nodes are deleted and the oldest node is not expired
                if (jiffies - node->exit_time < msecs_to_jiffies(EXITED_TASKS_EXPIRE_TIME_MS) &&
                    EXITED_TASKS_LRU_SIZE - exited_tasks_list_size > EXITED_TASKS_LRU_CLEAN_SIZE)
                {
                    break;
                }
                list_del(&node->list_node);
                list_add_tail(&node->list_node, &tmp_list);
                exited_tasks_list_size--;
            }
        }
        else
        {
            // exited tasks table is disabled during adding, just free the node
            list_add_tail(&node->list_node, &tmp_list);
        }
        spin_unlock(&exited_tasks_list_lock);

        while (!list_empty(&tmp_list))
        {
            node = list_first_entry(&tmp_list, exited_task_node_t, list_node);
            list_del(&node->list_node);
            exited_task_put(node);
        }
    }

    DPRINTF("Added exited task node for PID %d %lu", task->pid, unique_pid);

    return;
}

/**
@brief    Take a referenced snapshot of the cached exited tasks for a transport.
@details  Up to the list size observed on entry, nodes are copied (with their
          refcount incremented) into a newly allocated array that replaces the
          transport's current iteration state. Entries of a previous iteration
          that were not yet handed out are released and the old array is freed.
          When the cache is empty the transport is left untouched.
@param    transport  Transport that receives the snapshot.
@return   0 on success (including an empty cache), -1 when the array cannot be
          allocated.
*/
int exited_task_collect(transport_t *transport)
{
    exited_task_node_t **snapshot;
    exited_task_node_t **stale_nodes;
    exited_task_node_t *entry;
    uint32_t snapshot_len = 0;
    uint32_t capacity;
    int stale_idx;
    int stale_size;

    DPRINTF("Collect extied tasks for transport %p", transport);

    // Unlocked size read: only used to size the array; the copy loop below
    // never takes more than this many nodes.
    capacity = READ_ONCE(exited_tasks_list_size);
    if (capacity == 0)
    {
        DPRINTF("No exited tasks to collect");
        return 0;
    }

    snapshot = mem_alloc0(capacity * sizeof(*snapshot));
    if (!snapshot)
    {
        EPRINTF("Failed to allocate memory for exited tasks collection");
        return -1;
    }

    spin_lock(&exited_tasks_list_lock);
    list_for_each_entry(entry, &exited_tasks_list, list_node)
    {
        if (snapshot_len >= capacity)
        {
            break;
        }

        DPRINTF("Collect exited task node for PID %d %lu", entry->pid, entry->unique_pid);
        // The snapshot holds its own reference to each collected node.
        atomic_inc(&entry->refcount);
        snapshot[snapshot_len++] = entry;
    }
    spin_unlock(&exited_tasks_list_lock);

    // Install the new snapshot, taking over whatever iteration was in progress.
    spin_lock(&transport->exited_processes.lock);
    stale_nodes = transport->exited_processes.nodes;
    stale_idx = transport->exited_processes.index;
    stale_size = transport->exited_processes.size;

    transport->exited_processes.nodes = snapshot;
    transport->exited_processes.size = snapshot_len;
    transport->exited_processes.index = 0;
    spin_unlock(&transport->exited_processes.lock);

    if (stale_nodes)
    {
        // Only entries at or past the old index are still referenced by the array.
        for (; stale_idx < stale_size; stale_idx++)
        {
            exited_task_put(stale_nodes[stale_idx]);
        }
        mem_free(stale_nodes);
    }

    return 0;
}

/**
@brief    Release the snapshot of exited tasks attached to a transport.
@details  The array is detached from the transport while the transport lock is
          held, so that no other caller can observe (or free again) the array
          once it has been released. Entries that were not yet handed out by
          exited_task_get() are put, then the array itself is freed.
@param    transport  Transport whose collected exited tasks are dropped.
*/
void free_collected_exited_task(transport_t *transport)
{
    exited_task_node_t **nodes_to_free;
    int idx;
    int size;

    DPRINTF("Free collected exited tasks for transport %p", transport);

    spin_lock(&transport->exited_processes.lock);
    nodes_to_free = transport->exited_processes.nodes;
    if (!nodes_to_free)
    {
        DPRINTF("Exited tasks is not collected for %p", transport);
        spin_unlock(&transport->exited_processes.lock);
        return;
    }

    idx = transport->exited_processes.index;
    size = transport->exited_processes.size;

    // Take ownership of the array: leaving the pointer in the transport would
    // let a later get/collect/free touch or free the memory released below.
    transport->exited_processes.nodes = NULL;
    transport->exited_processes.index = 0;
    transport->exited_processes.size = 0;
    spin_unlock(&transport->exited_processes.lock);

    // Entries before idx were already handed out and belong to their receivers.
    while (idx < size)
    {
        exited_task_put(nodes_to_free[idx++]);
    }
    mem_free(nodes_to_free);
}

/**
@brief    Hand out the next node of the transport's exited tasks snapshot.
@details  The reference taken by exited_task_collect() is transferred to the
          caller together with the node. When the snapshot is exhausted the
          array is detached from the transport and freed.
@param    transport  Transport being iterated.
@return   Next node, or NULL when no snapshot is attached or all of its nodes
          have already been handed out.
*/
exited_task_node_t *exited_task_get(transport_t *transport)
{
    exited_task_node_t **exhausted = NULL;
    exited_task_node_t *next = NULL;

    spin_lock(&transport->exited_processes.lock);
    if (!transport->exited_processes.nodes)
    {
        DPRINTF("Exited tasks collection is not ready");
        spin_unlock(&transport->exited_processes.lock);
        return NULL;
    }

    if (transport->exited_processes.index >= transport->exited_processes.size)
    {
        // Nothing left: detach the array so it can be freed outside the lock.
        exhausted = transport->exited_processes.nodes;
        transport->exited_processes.nodes = NULL;
    }
    else
    {
        // Move the entry out of its slot and advance the cursor.
        next = transport->exited_processes.nodes[transport->exited_processes.index];
        transport->exited_processes.nodes[transport->exited_processes.index] = NULL;
        transport->exited_processes.index++;
    }
    spin_unlock(&transport->exited_processes.lock);

    if (!next)
    {
        DPRINTF("Exited tasks is all sent, free collected exited tasks for transport %p", transport);
    }
    else
    {
        // No extra reference is taken here: exited_task_collect already took it.
        DPRINTF("Found exited task node for PID %d %lu", next->pid, next->unique_pid);
    }

    if (exhausted)
    {
        mem_free(exhausted);
    }

    return next;
}

/**
@brief    Destroy an exited task node.
@details  Uninitializes the embedded message when its heap pointer is set, then
          returns the node to the exited_task_node cache.
@param    node  Node to destroy; must not be NULL (it is dereferenced).
*/
static void free_exited_task_node(exited_task_node_t *node)
{
    const bool has_heap_ptr = node->msg_varsized.data.heap.ptr;

    DPRINTF("Cleaning exited task node for PID %d", node->pid);

    if (has_heap_ptr)
    {
        msg_varsized_uninit(&node->msg_varsized);
    }

    KMEM_DELETE(exited_task_node, node);
}

/**
@brief    Drop one reference to an exited task node.
@details  The node is destroyed when the dropped reference was the last one.
@param    node  Node to release; NULL is accepted and ignored.
*/
void exited_task_put(exited_task_node_t *node)
{
    // The short-circuit keeps the NULL case from touching the refcount.
    if (node && atomic_dec_and_test(&node->refcount))
    {
        free_exited_task_node(node);
    }
}

/**
@brief    Callback installed on do_exit through do_exit_fprobe.
@details  Records the current task in the exited tasks cache, but only when its
          pid equals its tgid; other tasks are ignored. None of the callback
          arguments are used.
*/
static void exit_entry_handler(unsigned long ip, unsigned long parent_ip, struct ftrace_ops *op, struct pt_regs *regs)
{
    struct task_struct *exiting = current;

    (void)ip;
    (void)parent_ip;
    (void)op;
    (void)regs;

    if (exiting->pid != exiting->tgid)
    {
        return;
    }

    exited_tasks_add(exiting);
}

// Probe descriptor for the "do_exit" symbol. It is passed to
// register_ftrace_post_event() in exited_tasks_start() and to
// unregister_ftrace_post_event() in exited_tasks_stop().
static fp_probe_t do_exit_fprobe = {
    .ops = {
        // Cast needed because the handler's prototype is spelled out explicitly.
        .func = (ftrace_func_t)exit_entry_handler,
        .flags = 0,
    },
    // NOTE(review): .fn and .registered are presumably filled in by the
    // registration helper - confirm against ftrace_hooks/ftrace_events.h.
    .fn = 0,
    .name = "do_exit",
    .registered = false,
};

static int unlink_path(struct path *path)
{
    if (!S_ISREG(path->dentry->d_inode->i_mode))
    {
        return -EINVAL;
    }
#ifdef VFS_UNLINK_WITH_MNT_IDMAP
    // https://elixir.bootlin.com/linux/v6.16.4/source/fs/smb/server/vfs.c#L610
    return vfs_unlink(mnt_idmap(path->mnt), path->dentry->d_parent->d_inode, path->dentry, NULL);
#elif defined(VFS_UNLINK_WITH_USER_NAMESPACE)
    // https://elixir.bootlin.com/linux/v5.19.17/source/fs/ksmbd/vfs.c#L625
    return vfs_unlink(mnt_user_ns(path->mnt), path->dentry->d_parent->d_inode, path->dentry, NULL);
#elif defined(VFS_UNLINK_WITH_DELEGATED_INODE)
    // https://elixir.bootlin.com/linux/v3.18.139/source/fs/nfsd/vfs.c#L1751
    return vfs_unlink(path->dentry->d_parent->d_inode, path->dentry, NULL);
#else
    // https://elixir.bootlin.com/linux/v2.6.39.4/source/fs/nfsd/vfs.c#L1838
    return vfs_unlink(path->dentry->d_parent->d_inode, path->dentry);
#endif
}

static int kernel_path_exists(const char *path)
{
    struct path kernel_path;
    int ret;

    ret = kern_path(path, 0, &kernel_path);
    if (ret)
    {
        return ret;
    }

    ret = unlink_path(&kernel_path);
    if (ret)
    {
        EPRINTF("Failed to unlink %s, error %d", path, ret);
    }
    else
    {
        DPRINTF("Unlinked %s", path);
    }

    path_put(&kernel_path);

    return 0;
}

/**
@brief    Initialize the exited tasks cache.
@details  Creates the node cache and the list, spinlock and start/stop mutex.
          When EARLY_PROC_TRACKER_PATH can be resolved (it is unlinked by the
          check), listening is started right away; the result of that start is
          not propagated to the caller.
@return   0 on success, -1 when the node cache cannot be created.
*/
int exited_tasks_init(void)
{
    int marker_lookup;

    KMEM_STRUCT_CACHE_NAME(exited_task_node) = NULL;
    if (!KMEM_STRUCT_CACHE_INIT(exited_task_node, 0, NULL))
    {
        EPRINTF("Failed to create exited_task_node cache");
        return -1;
    }

    INIT_LIST_HEAD(&exited_tasks_list);
    spin_lock_init(&exited_tasks_list_lock);
    mutex_init(&exited_tasks_list_startstop_mutex);
    DPRINTF("Exited tasks table initialized");

    marker_lookup = kernel_path_exists(EARLY_PROC_TRACKER_PATH);
    if (marker_lookup == 0)
    {
        exited_tasks_start();
    }

    return 0;
}

int exited_tasks_start(void)
{
    int ret;
    mutex_lock(&exited_tasks_list_startstop_mutex);

    if (exited_tasks_listening)
    {
        DPRINTF("Exited tasks table is already enabled");
        mutex_unlock(&exited_tasks_list_startstop_mutex);
        return 0;
    }
    WRITE_ONCE(exited_tasks_listening, true);

    ret = register_ftrace_post_event(&do_exit_fprobe, &do_exit_fprobe.ops);
    mutex_unlock(&exited_tasks_list_startstop_mutex);

    if (ret == 0)
    {
        DPRINTF("Exited tasks table started");
        return 0;
    }
    else
    {
        EPRINTF("Failed to register do_exit_fprobe");
        return -1;
    }
}

static void clear_exited_tasks(void)
{
    exited_task_node_t *node = NULL;
    uint32_t remaining_size = 0;
    struct list_head tmp_list;
    INIT_LIST_HEAD(&tmp_list);

    spin_lock(&exited_tasks_list_lock);
    while (!list_empty(&exited_tasks_list))
    {
        node = list_first_entry(&exited_tasks_list, exited_task_node_t, list_node);
        list_del(&node->list_node);
        exited_tasks_list_size--;
        list_add_tail(&node->list_node, &tmp_list);
        remaining_size += 1;
    }
    spin_unlock(&exited_tasks_list_lock);

    DPRINTF("Clearing exited tasks table, remaining size: %u", remaining_size);
    while (!list_empty(&tmp_list))
    {
        node = list_first_entry(&tmp_list, exited_task_node_t, list_node);
        list_del(&node->list_node);
        exited_task_put(node);
    }
}

/**
@brief    Stop caching exited tasks.
@details  Under the start/stop mutex: clears the listening flag, unregisters
          the do_exit probe and empties the cache. Does nothing besides logging
          when the cache is not listening.
*/
void exited_tasks_stop(void)
{
    mutex_lock(&exited_tasks_list_startstop_mutex);

    if (exited_tasks_listening)
    {
        // Flag first, so that handlers still running stop adding nodes
        // before the list is emptied.
        WRITE_ONCE(exited_tasks_listening, false);
        unregister_ftrace_post_event(&do_exit_fprobe);
        clear_exited_tasks();
    }
    else
    {
        DPRINTF("Exited tasks table is not enabled");
    }

    mutex_unlock(&exited_tasks_list_startstop_mutex);
}

/**
@brief    Tear down the exited tasks cache.
@details  Stops listening (which also empties the cache) and then destroys the
          exited_task_node cache.
*/
void exited_tasks_deinit(void)
{
    // Stop first: it releases the cached nodes before their cache goes away.
    exited_tasks_stop();

    KMEM_STRUCT_CACHE_DEINIT(exited_task_node);

    DPRINTF("Exited tasks table deinitialized");
}
