/**
@file     udp_socket_manager.c
@brief    Manager for UDP socket flow
@details  Copyright (c) 2025 Acronis International GmbH
@author   Denis Kopyrin (denis.kopyrin@acronis.com)
@since    $Id: $
*/

#include "udp_socket_manager.h"

#include "hashtable_compat.h"
#include "hash_fast.h"
#include "memory.h"
#include "net_compat.h"
#include "net_events.h"

#include <linux/in.h>
#include <linux/in6.h>
#include <linux/jiffies.h>
#include <linux/list.h>
#include <net/ipv6.h>
#include <net/sock.h>

// Both hashtables are declared with DECLARE_HASHTABLE(..., TABLE_SIZE_BITS),
// i.e. 2^15 = 32768 buckets each.
#define TABLE_SIZE_BITS 15
// NOTE: despite its name this is not the bucket count. sweep_impl() uses it as the
// maximum number of tracked flows (half the number of buckets).
#define TABLE_SIZE (1 << (TABLE_SIZE_BITS - 1)) // 16384

// Fallback for kernels whose <linux/list.h> does not provide list_first_entry_or_null
#ifndef list_first_entry_or_null
#define list_first_entry_or_null(ptr, type, member) (list_empty(ptr) ? NULL : list_first_entry(ptr, type, member))
#endif

#ifdef KERNEL_MOCK
#include "mock/mock.h"
#endif

// Magic of the socket pseudo-filesystem; used to recognise socket inodes when they are freed
#ifndef SOCKFS_MAGIC
#define SOCKFS_MAGIC		0x534F434B
#endif

// Idle lifetime of a flow entry: 30 seconds, pushed forward on every matching sendmsg
#define TTL msecs_to_jiffies(30000)

// Variable-length lookup key for a UDP "flow".
// NOTE: key_hash() hashes sizeof(inode) + data_sz bytes starting at &inode, so 'data'
// must follow 'inode' directly in memory - do not reorder or insert fields here.
typedef struct
{
	// number of valid bytes in 'data'
	size_t data_sz;
	// inode of the socket the flow belongs to (SOCK_INODE(sock))
	const struct inode *inode;
	// packed address payload: hashtable_udp_socket_key_data_ipv4_t or _ipv6_t
	char data[];
} hashtable_udp_socket_key_t;

// Address payloads stored in hashtable_udp_socket_key_t::data.
// They are PACKED so there are no padding bytes: socket_make_key() writes every field,
// and key_hash()/key_equal() hash and compare the raw bytes, so padding would make
// equal flows hash differently.
// Addresses and ports are copied verbatim from sockaddr_in / sockaddr_in6 (network byte order).
typedef struct PACKED
{
	uint32_t remote;
	uint32_t local;
	uint16_t remote_port;
	uint16_t local_port;
} hashtable_udp_socket_key_data_ipv4_t;

typedef struct PACKED
{
	uint32_t remote[4];
	uint32_t local[4];
	uint16_t remote_port;
	uint16_t local_port;
} hashtable_udp_socket_key_data_ipv6_t;

// Stack-allocatable key large enough for either address family.
// socket_make_key() fills the union and sets header.data_sz; the other helpers then read
// the payload through header.data, i.e. the union is expected to overlay header.data.
// NOTE(review): this relies on the union starting exactly at offsetof(header.data) and on
// the GCC extension allowing a struct with a flexible array member as a non-last member -
// consider a BUILD_BUG_ON on the offsets to make the assumption explicit.
typedef struct
{
	hashtable_udp_socket_key_t header;
	union
	{
		hashtable_udp_socket_key_data_ipv4_t ipv4;
		hashtable_udp_socket_key_data_ipv6_t ipv6;
	};
} hashtable_udp_socket_key_common_t;

// One tracked flow. A node is linked into both hashtables and the LRU list at once,
// and unlinked from all three by erase_impl().
typedef struct
{
	// starts at 1 (the tables' reference); find_ref_rcu() takes temporary extra refs
	atomic_t refcount;

	// hashed by inode
	struct hlist_node hash_inode_node;
	// hashed by inode + key
	struct hlist_node hash_addr_node;

	// true while linked on seen_flows_lru_list; refresh_impl() checks it before moving the node
	bool lru_list_inserted;
	// jiffies value after which sweep_impl() expires the node
	unsigned long lru_deadline;
	struct list_head lru_list_node;

	// used by node_put() to defer the free past concurrent RCU readers
	struct rcu_head rcu;

	// must stay last: it ends in a flexible array sized at allocation time
	hashtable_udp_socket_key_t key;
} hashtable_udp_socket_node_t;

// Global state for UDP flow de-duplication.
// After init, every change to the hashtables, the LRU list, 'active' and 'flows_count'
// in this file is made under table_writer_lock. The hashtables are also read without
// the lock, inside rcu_read_lock() (find_ref_rcu(), inode_is_inserted_rcu()).
typedef struct udp_socket_manager
{
	// serializes all writers: insert, refresh, sweep, erase, (de)activation
	struct mutex table_writer_lock;
	// new flows are recorded only while true
	bool active;
	// number of linked nodes; sweep_impl() keeps it at or below TABLE_SIZE
	ssize_t flows_count;
	// Hashtable for all the flows that were seen by UDP.
	// UDP has no notion of a "flow" the way TCP does. One is introduced here
	// (socket inode + local/remote address and port) so that the client is not
	// spammed with the same event for every datagram.
	DECLARE_HASHTABLE(seen_flows_addr_hashtable , TABLE_SIZE_BITS);
	// Secondary index over the same nodes, keyed by socket inode only. One inode can
	// appear many times because sendmsg may target multiple destination addresses.
	// It is used to drop every flow of a socket when its inode is freed; ordinary
	// lookups should use 'seen_flows_addr_hashtable'.
	DECLARE_HASHTABLE(seen_flows_inode_hashtable, TABLE_SIZE_BITS);
	// All tracked nodes, earliest deadline at the head (nodes are appended on insert/refresh)
	struct list_head seen_flows_lru_list;
} udp_socket_manager_t;

// Allocated by udp_socket_manager_init(), released by udp_socket_manager_deinit()
static udp_socket_manager_t* global_udp_socket_manager;

// MARK: UDP socket node
// Allocate an unlinked node holding a copy of @key (header plus payload).
// The caller receives the only reference and must either insert the node or node_put() it.
// Returns NULL on allocation failure.
static hashtable_udp_socket_node_t* node_alloc(const hashtable_udp_socket_key_t* key)
{
	// The payload lives in the trailing flexible array of fresh->key
	hashtable_udp_socket_node_t* fresh = mem_alloc(sizeof(*fresh) + key->data_sz);

	if (fresh) {
		atomic_set(&fresh->refcount, 1);
		// Not on the LRU list until udp_socket_manager_sendmsg() links it
		fresh->lru_list_inserted = false;

		// Copy field by field plus a memcpy of the payload instead of a whole-struct
		// assignment. The original author notes the kernel flags the latter as an unsafe
		// overwrite of a struct ending in a flexible array.
		fresh->key.data_sz = key->data_sz;
		fresh->key.inode = key->inode;
		memcpy(fresh->key.data, key->data, key->data_sz);
	}

	return fresh;
}

// call_rcu() callback: releases a node once no RCU reader can still observe it.
static void node_rcu_free(struct rcu_head *rcu)
{
	mem_free(container_of(rcu, hashtable_udp_socket_node_t, rcu));
}

// Drop one reference to @node. When the last one goes, the free is deferred through
// call_rcu() because lockless readers may still be walking the hash chains.
static void node_put(hashtable_udp_socket_node_t* node)
{
	if (!atomic_dec_and_test(&node->refcount))
		return;

	call_rcu(&node->rcu, node_rcu_free);
}

// Copy the IPv4 address and port out of @sa into the key fields at @out_addr / @out_port.
// Returns false (writing nothing) when @sa_len is too short to hold a sockaddr_in.
static bool pack4(const struct sockaddr_in* sa, int sa_len, void* out_addr, void* out_port)
{
	const bool fits = sa_len >= (int) sizeof(*sa);

	if (fits) {
		__builtin_memcpy(out_addr, &sa->sin_addr, sizeof(sa->sin_addr));
		__builtin_memcpy(out_port, &sa->sin_port, sizeof(sa->sin_port));
	}

	return fits;
}

// Copy the IPv6 address and port out of @sa6 into the key fields at @out_addr / @out_port.
// Only the RFC 2133 prefix of sockaddr_in6 (everything up to and including sin6_addr) is
// required, so a shorter @sa_len is rejected and nothing is written.
static bool pack6(const struct sockaddr_in6* sa6, int sa_len, void* out_addr, void* out_port)
{
	const bool fits = sa_len >= SIN6_LEN_RFC2133;

	if (fits) {
		__builtin_memcpy(out_addr, &sa6->sin6_addr, sizeof(sa6->sin6_addr));
		__builtin_memcpy(out_port, &sa6->sin6_port, sizeof(sa6->sin6_port));
	}

	return fits;
}

// Store the remote endpoint from @sa into the family-specific half of @key.
// Returns false for an unsupported @family or a too-short address.
static bool pack_remote(hashtable_udp_socket_key_common_t *key, int family, void* sa, int sa_len)
{
	if (family == AF_INET)
		return pack4(sa, sa_len, &key->ipv4.remote, &key->ipv4.remote_port);

	if (family == AF_INET6)
		return pack6(sa, sa_len, &key->ipv6.remote, &key->ipv6.remote_port);

	return false;
}

// Store the local endpoint from @sa into the family-specific half of @key.
// Returns false for an unsupported @family or a too-short address.
static bool pack_local(hashtable_udp_socket_key_common_t *key, int family, void* sa, int sa_len)
{
	if (family == AF_INET)
		return pack4(sa, sa_len, &key->ipv4.local, &key->ipv4.local_port);

	if (family == AF_INET6)
		return pack6(sa, sa_len, &key->ipv6.local, &key->ipv6.local_port);

	return false;
}

// MARK: UDP socket key
// Build the flow key (socket inode, remote addr/port, local addr/port) for a sendmsg on @sock.
//
// The remote endpoint comes from the explicit destination in msg->msg_name when one is
// given, otherwise from the socket's peer (sock_to_addr(PEER_REMOTE_ALWAYS)). The local
// endpoint always comes from sock_to_addr(PEER_LOCAL).
//
// Returns false when the socket family is not AF_INET/AF_INET6, when the destination
// cannot be represented in this socket's key layout, or when an endpoint cannot be
// packed. The caller then skips flow tracking for this message.
static bool socket_make_key(hashtable_udp_socket_key_common_t *key, struct socket *sock, struct msghdr *msg)
{
	int family = sock->sk->sk_family;
	struct sockaddr *dest = msg->msg_name;

	switch (family)
	{
		case AF_INET:
			key->header.data_sz = sizeof(hashtable_udp_socket_key_data_ipv4_t);
			break;
		case AF_INET6:
			key->header.data_sz = sizeof(hashtable_udp_socket_key_data_ipv6_t);
			break;
		default:
			return false;
	}

	key->header.inode = SOCK_INODE(sock);

	if (dest) {
		// sa_family has to be present before the rest of the address can be interpreted
		if (msg->msg_namelen < (int) sizeof(dest->sa_family))
			return false;

		// The handling below mirrors upstream udp_sendmsg()/udpv6_sendmsg().
		// NOTE(review): re-check against the kernel versions this module targets.
		if (dest->sa_family == AF_UNSPEC && family == AF_INET6) {
			// IPv6 UDP ignores an AF_UNSPEC destination and sends to the connected peer
			dest = NULL;
		} else if (dest->sa_family != family && dest->sa_family != AF_UNSPEC) {
			// A foreign-family destination (e.g. AF_INET through a dual-stack AF_INET6
			// socket) would otherwise be decoded with this socket's layout, producing a
			// bogus key that can merge unrelated flows. Do not track it.
			// (AF_UNSPEC on an AF_INET socket is accepted by IPv4 UDP as a plain address.)
			return false;
		}
	}

	if (dest) {
		if (!pack_remote(key, family, dest, msg->msg_namelen))
			return false;
	} else {
		struct sockaddr_storage remote_storage_addr;
		int remote_addr_len = sock_to_addr(sock, &remote_storage_addr, PEER_REMOTE_ALWAYS);
		if (!pack_remote(key, family, &remote_storage_addr, remote_addr_len))
			return false;
	}

	{
		struct sockaddr_storage local_storage_addr;
		int local_addr_len = sock_to_addr(sock, &local_storage_addr, PEER_LOCAL);
		if (!pack_local(key, family, &local_storage_addr, local_addr_len))
			return false;
	}

	return true;
}

// True when both keys describe the same flow: same payload size, same socket inode and
// byte-identical payload.
static bool key_equal(const hashtable_udp_socket_key_t* k1, const hashtable_udp_socket_key_t* k2)
{
	// Checking the sizes first also guarantees memcmp stays within both payloads
	if (k1->data_sz != k2->data_sz)
		return false;

	if (k1->inode != k2->inode)
		return false;

	return memcmp(k1->data, k2->data, k1->data_sz) == 0;
}

// Bucket index in seen_flows_addr_hashtable for @key.
// The inode pointer and the payload right after it are hashed as one byte range, which
// relies on 'data' directly following 'inode' in hashtable_udp_socket_key_t.
static int key_hash(const hashtable_udp_socket_key_t* key)
{
	const size_t hashed_len = sizeof(key->inode) + key->data_sz;

	// Keep the top TABLE_SIZE_BITS bits as the bucket index
	// (assumes murmur_hash() yields a 64-bit value - confirm against hash_fast.h)
	return murmur_hash(&key->inode, hashed_len) >> (64 - TABLE_SIZE_BITS);
}

// Bucket index in seen_flows_inode_hashtable for @inode.
// Presumably moremur_hash(x, bits) returns a value below (1 << bits), since the result is
// used directly as a bucket index - confirm against hash_fast.h.
static int inode_hash(const struct inode* inode)
{
	// Convert through uintptr_t: a direct pointer -> uint64_t cast is a pointer/integer
	// size mismatch on 32-bit targets (-Wpointer-to-int-cast)
	return moremur_hash((uint64_t) (uintptr_t) inode, TABLE_SIZE_BITS);
}

// MARK: UDP socket table

// Several nodes can share one inode bucket: a socket may own many flows, and distinct
// inodes may hash to the same bucket. Callers only need to know whether any node for the inode exists.
// Lockless check whether bucket @hash of the inode table holds a node for @inode.
// Must be called inside rcu_read_lock().
static bool inode_is_inserted_rcu(int hash, const struct inode* inode)
{
	hashtable_udp_socket_node_t *entry;
	bool found = false;

	hlist_for_each_entry_rcu(entry, &global_udp_socket_manager->seen_flows_inode_hashtable[hash], hash_inode_node) {
		if (entry->key.inode == inode) {
			found = true;
			break;
		}
	}

	return found;
}

// Lockless lookup of @key in bucket @hash of the address table. Must be called inside
// rcu_read_lock(). On success a reference is taken and the caller must node_put() it.
// Returns NULL when the key is absent or its node is already being released.
static hashtable_udp_socket_node_t* find_ref_rcu(int hash, const hashtable_udp_socket_key_t* key)
{
	hashtable_udp_socket_node_t *candidate;

	hlist_for_each_entry_rcu(candidate, &global_udp_socket_manager->seen_flows_addr_hashtable[hash], hash_addr_node) {
		if (key_equal(&candidate->key, key)) {
			// Inserts check for duplicates under the writer lock, so this is the only match.
			// A refcount of zero means the node is on its way to being freed.
			return atomic_inc_not_zero(&candidate->refcount) ? candidate : NULL;
		}
	}

	return NULL;
}

// Lookup of @key in bucket @hash of the address table for writers holding
// table_writer_lock. No reference is taken. Returns NULL when absent.
static hashtable_udp_socket_node_t* find(int hash, const hashtable_udp_socket_key_t* key)
{
	hashtable_udp_socket_node_t *entry;
	hashtable_udp_socket_node_t *match = NULL;

	hlist_for_each_entry(entry, &global_udp_socket_manager->seen_flows_addr_hashtable[hash], hash_addr_node) {
		if (key_equal(&entry->key, key)) {
			match = entry;
			break;
		}
	}

	return match;
}

// Unlink @node from both hashtables and the LRU list, then drop the tables' reference.
// Caller holds table_writer_lock.
static void erase_impl(hashtable_udp_socket_node_t *node)
{
	udp_socket_manager_t *mgr = global_udp_socket_manager;

	// Unpublish from both indexes; RCU readers may still see the node until a grace period ends
	hash_del_rcu(&node->hash_addr_node);
	hash_del_rcu(&node->hash_inode_node);

	// Leave the LRU; clearing the flag stops refresh_impl() from touching the unlinked list node
	list_del(&node->lru_list_node);
	node->lru_list_inserted = false;
	mgr->flows_count--;

	// The memory itself is released via RCU once the last holder drops its reference
	node_put(node);
}

// Erase every flow that belongs to @inode from bucket @hash of the inode table.
// Caller holds table_writer_lock.
static void inode_erase_all_impl(int hash, const struct inode* inode)
{
	struct hlist_head *bucket = &global_udp_socket_manager->seen_flows_inode_hashtable[hash];
	hashtable_udp_socket_node_t *entry;
	struct hlist_node *next;

	// The _safe iterator is required because erase_impl() unlinks the current entry
	hlist_for_each_entry_safe(entry, next, bucket, hash_inode_node) {
		if (entry->key.inode != inode)
			continue;

		erase_impl(entry);
	}
}

// Extend @node's lifetime by TTL and move it to the LRU tail (most recently used).
// Nodes already erased from the tables are left alone. Caller holds table_writer_lock.
static void refresh_impl(hashtable_udp_socket_node_t* node)
{
	if (!node->lru_list_inserted)
		return;

	node->lru_deadline = jiffies + TTL;
	list_del(&node->lru_list_node);
	list_add_tail(&node->lru_list_node, &global_udp_socket_manager->seen_flows_lru_list);
}

// MARK: UDP socket manager
int udp_socket_manager_init(void)
{
	global_udp_socket_manager = vmem_alloc(sizeof(udp_socket_manager_t));
	if (!global_udp_socket_manager)
		return -ENOMEM;

	mutex_init(&global_udp_socket_manager->table_writer_lock);
	global_udp_socket_manager->active = false;
	global_udp_socket_manager->flows_count = 0;
	hash_init(global_udp_socket_manager->seen_flows_inode_hashtable);
	hash_init(global_udp_socket_manager->seen_flows_addr_hashtable);
	INIT_LIST_HEAD(&global_udp_socket_manager->seen_flows_lru_list);

	return 0;
}

void udp_socket_manager_deinit(void)
{
	if (!global_udp_socket_manager)
		return;

	vmem_free(global_udp_socket_manager);
}

// Start recording new UDP flows.
void udp_socket_manager_activate(void)
{
	struct mutex *lock = &global_udp_socket_manager->table_writer_lock;

	mutex_lock(lock);
	global_udp_socket_manager->active = true;
	mutex_unlock(lock);
}

void udp_socket_manager_deactivate(void)
{
	mutex_lock(&global_udp_socket_manager->table_writer_lock);
	if (global_udp_socket_manager->active) {
		while (1)
		{
			hashtable_udp_socket_node_t *node = list_first_entry_or_null(&global_udp_socket_manager->seen_flows_lru_list, hashtable_udp_socket_node_t, lru_list_node);
			if (!node)
				break;

			erase_impl(node);
		}

		global_udp_socket_manager->active = false;
	}
	mutex_unlock(&global_udp_socket_manager->table_writer_lock);
}

// Inode-free hook: drop every tracked flow keyed by @inode so that a later inode reusing
// the same address cannot match stale entries. Non-socket inodes are ignored.
void udp_socket_manager_inode_free_security(const struct inode *inode)
{
	int hash;
	bool listed;

	// Only sockfs inodes can ever appear as keys; filter before doing any hashing
	if (!inode->i_sb || inode->i_sb->s_magic != SOCKFS_MAGIC)
		return;

	// Must use the same bucket function as insertion in udp_socket_manager_sendmsg()
	hash = inode_hash(inode);

	// Cheap lockless pre-check so the common case (socket never tracked) avoids the mutex
	rcu_read_lock();
	listed = inode_is_inserted_rcu(hash, inode);
	rcu_read_unlock();

	if (!listed)
		return;

	mutex_lock(&global_udp_socket_manager->table_writer_lock);
	inode_erase_all_impl(hash, inode);
	mutex_unlock(&global_udp_socket_manager->table_writer_lock);
}

// Evict expired flows, then evict the least recently used ones until at most TABLE_SIZE
// remain. Caller holds table_writer_lock.
static void sweep_impl(void)
{
	struct list_head *lru = &global_udp_socket_manager->seen_flows_lru_list;
	hashtable_udp_socket_node_t *oldest;

	// Pass 1: deadlines increase from head to tail, so stop at the first live node
	for (;;) {
		oldest = list_first_entry_or_null(lru, hashtable_udp_socket_node_t, lru_list_node);
		if (!oldest || !time_after(jiffies, oldest->lru_deadline))
			break;

		erase_impl(oldest);
	}

	// Pass 2: enforce the capacity limit by dropping from the head
	while (global_udp_socket_manager->flows_count > TABLE_SIZE) {
		oldest = list_first_entry_or_null(lru, hashtable_udp_socket_node_t, lru_list_node);
		if (!oldest)
			break;

		erase_impl(oldest);
	}
}

void udp_socket_manager_sendmsg(task_info_t* task_info, struct socket *sock, struct msghdr *msg, int size)
{
	const struct inode *inode = SOCK_INODE(sock);
	hashtable_udp_socket_node_t *node;
	hashtable_udp_socket_key_common_t key;
	int hash_addr;
	int hash_inode;
	bool send = false;

	if (!socket_make_key(&key, sock, msg))
		return;

	hash_addr = key_hash(&key.header);
	hash_inode = inode_hash(inode);

	// Lookup for existing node and refresh
	rcu_read_lock();
	node = find_ref_rcu(hash_addr, &key.header);
	rcu_read_unlock();

	if (node) {
		mutex_lock(&global_udp_socket_manager->table_writer_lock);
		refresh_impl(node);
		sweep_impl();
		mutex_unlock(&global_udp_socket_manager->table_writer_lock);
		node_put(node);

		return;
	}

	// No node found, create a new one
	node = node_alloc(&key.header);
	if (!node)
		return;

	mutex_lock(&global_udp_socket_manager->table_writer_lock);
	if (global_udp_socket_manager->active) {
		hashtable_udp_socket_node_t *found_node;
		sweep_impl();

		found_node = find(hash_addr, &key.header);
		if (found_node) {
			refresh_impl(found_node);
			node_put(node);
		} else {
			hlist_add_head_rcu(&node->hash_inode_node, &global_udp_socket_manager->seen_flows_inode_hashtable[hash_inode]);
			hlist_add_head_rcu(&node->hash_addr_node , &global_udp_socket_manager->seen_flows_addr_hashtable[hash_addr]);
			list_add_tail(&node->lru_list_node, &global_udp_socket_manager->seen_flows_lru_list);
			node->lru_list_inserted = true;
			node->lru_deadline = jiffies + TTL;
			global_udp_socket_manager->flows_count++;
			send = true;
		}

	} else {
		node_put(node);
	}
	mutex_unlock(&global_udp_socket_manager->table_writer_lock);

	if (send) {
		// TODO: Analyze QUIC + DNS
		(void) size;
		net_event_sendmsg_udp(task_info, sock, msg);
	}
}
