/**
@file     raw_socket_manager.cpp
@brief    Manager for RAW and netlink sockets
@details  Copyright (c) 2025 Acronis International GmbH
@author   Denis Kopyrin (denis.kopyrin@acronis.com)
@since    $Id: $
*/

#include "raw_socket_manager.h"

#include <linux/magic.h>
#include <net/sock.h>

#include "hashtable_compat.h"
#include "memory.h"
#include "net_events.h"
#include "transport.h"

// Fallback for builds where the included headers do not provide SOCKFS_MAGIC.
// The value is the ASCII string "SOCK"; it is compared against a superblock's
// 's_magic' in this file to recognise inodes that belong to sockets.
#ifndef SOCKFS_MAGIC
#define SOCKFS_MAGIC		0x534F434B
#endif

// MARK: Hashtables
// Number of hash bits, passed both to DECLARE_HASHTABLE() and to moremur_hash()
// (presumably 2^8 = 256 buckets - confirm against hashtable_compat.h).
#define RAW_SOCKET_EVENTS_TABLE_SIZE_BITS 8

// One hashtable entry: marks a single socket inode as tracked by the manager.
typedef struct
{
	// linkage into one bucket of the hashtable
	struct hlist_node hash_node;
	// used as a key (the table code in this file only compares the pointer)
	const struct inode *inode;
	// lets an unlinked node be released through call_rcu()
	struct rcu_head rcu;
} hashtable_raw_socket_node_t;

// State of the raw socket manager (instantiated once as 'global_raw_socket_manager').
typedef struct raw_socket_manager
{
	// taken around every hashtable modification and around updates/reads of 'active'
	// that happen after initialization
	struct mutex table_writer_lock;
	// while false, new entries are rejected by raw_socket_table_insert()
	bool active;
	// First 'recvmsg' after 'bind' is added to 'recvmsg_sniff_hashtable' to trigger 'recvmsg'.
	// In the future, it is possible to keep sockets that need to be sniffed in this hashtable
	// if multiple 'recvmsg' sniffing is required.
	DECLARE_HASHTABLE(recvmsg_sniff_hashtable, RAW_SOCKET_EVENTS_TABLE_SIZE_BITS);
} raw_socket_manager_t;

// Single file-scope manager instance; set up by raw_socket_manager_init().
static raw_socket_manager_t global_raw_socket_manager;

// RCU callback handed to call_rcu(): maps the embedded 'rcu' member back to
// its enclosing hashtable node and releases that node with mem_free().
static void raw_socket_node_rcu_free(struct rcu_head *rcu)
{
	hashtable_raw_socket_node_t *expired;

	expired = container_of(rcu, hashtable_raw_socket_node_t, rcu);
	mem_free(expired);
}

// Empties every bucket of 'heads' ('size' buckets in total).
// Each node is unlinked with hlist_del_init_rcu() and its memory is released
// later through call_rcu(), so lookups still traversing the list stay valid.
// The only caller in this file invokes this with 'table_writer_lock' held.
static void raw_socket_table_free(struct hlist_head* heads, int size) {
	int bucket;
	for (bucket = 0; bucket < size; bucket++) {
		struct hlist_node* front;
		// Keep popping the bucket head until the bucket is empty.
		for (front = heads[bucket].first; front; front = heads[bucket].first) {
			hashtable_raw_socket_node_t* victim = hlist_entry(front, hashtable_raw_socket_node_t, hash_node);
			hlist_del_init_rcu(front);
			call_rcu(&victim->rcu, raw_socket_node_rcu_free);
		}
	}
}

static bool raw_socket_table_is_inserted_rcu(const struct hlist_head* heads, int hash, const struct inode *inode) {
	hashtable_raw_socket_node_t *search_node;
	hlist_for_each_entry_rcu(search_node, &heads[hash], hash_node) {
		if (search_node->inode == inode) {
			return true;
		}
	}
	return false;
}

// Unlinks the entry keyed by 'inode' from bucket 'hash' of 'heads'.
// Returns true and stores the unlinked node in '*pnode' when an entry was found;
// the node is NOT freed here, the caller is expected to hand it to call_rcu().
// Returns false when no entry matches; '*pnode' then receives the iterator's
// value after a full traversal and is not used by the callers in this file.
// Uses the non-RCU iterator: callers in this file hold 'table_writer_lock'.
static bool raw_socket_table_erase_impl(struct hlist_head* heads, int hash, const struct inode *inode, hashtable_raw_socket_node_t **pnode) {
	hashtable_raw_socket_node_t* cursor;
	hlist_for_each_entry(cursor, &heads[hash], hash_node) {
		if (cursor->inode == inode) {
			// Safe to unlink mid-walk because the traversal stops right here.
			hlist_del_init_rcu(&cursor->hash_node);
			*pnode = cursor;
			return true;
		}
	}

	*pnode = cursor;
	return false;
}

static bool raw_socket_table_erase(struct hlist_head* heads, const struct inode *inode) {
	bool found = false;
	hashtable_raw_socket_node_t *node;
	unsigned int hash = moremur_hash((uint64_t) inode, RAW_SOCKET_EVENTS_TABLE_SIZE_BITS);

	mutex_lock(&global_raw_socket_manager.table_writer_lock);
	found = raw_socket_table_erase_impl(heads, hash, inode, &node);
	mutex_unlock(&global_raw_socket_manager.table_writer_lock);

	if (found) {
		call_rcu(&node->rcu, raw_socket_node_rcu_free);
	}

	return found;
}

static void raw_socket_table_insert(struct hlist_head* heads, const struct inode *inode)
{
	bool failed_to_insert = false;
	hashtable_raw_socket_node_t *search_node;
	hashtable_raw_socket_node_t *node;
	unsigned int hash = moremur_hash((uint64_t) inode, RAW_SOCKET_EVENTS_TABLE_SIZE_BITS);
	rcu_read_lock();
	if (raw_socket_table_is_inserted_rcu(heads, hash, inode)) {
		rcu_read_unlock();
		return;
	}
	rcu_read_unlock();

	node = (hashtable_raw_socket_node_t*) mem_alloc(sizeof(*node));
	if (!node) {
		return;
	}

	node->inode = inode;
	mutex_lock(&global_raw_socket_manager.table_writer_lock);
	if (global_raw_socket_manager.active) {
		hlist_for_each_entry(search_node, &heads[hash], hash_node) {
			if (search_node->inode == inode) {
				failed_to_insert = true;
				break;
			}
		}
		if (!failed_to_insert) {
			hlist_add_head_rcu(&node->hash_node, &heads[hash]);
		}
	} else {
		failed_to_insert = true;
	}
	mutex_unlock(&global_raw_socket_manager.table_writer_lock);

	if (failed_to_insert) {
		mem_free(node);
	}
}

// Self-contained lookup: hashes 'inode', then checks the matching bucket of
// 'heads' inside its own RCU read-side section.
// Returns true when an entry keyed by 'inode' is present.
static bool raw_socket_table_is_inserted(const struct hlist_head* heads, const struct inode *inode) {
	bool present;
	unsigned int bucket = moremur_hash((uint64_t) inode, RAW_SOCKET_EVENTS_TABLE_SIZE_BITS);

	rcu_read_lock();
	present = raw_socket_table_is_inserted_rcu(heads, bucket, inode);
	rcu_read_unlock();
	return present;
}

// Lookup of 'inode' in the manager's recvmsg sniff table using a hash the
// caller has already computed. The caller provides the RCU read-side section.
static bool raw_recvmsg_sniff_is_inserted_rcu(const struct inode *inode, unsigned int hash)
{
	const struct hlist_head *sniff_table = global_raw_socket_manager.recvmsg_sniff_hashtable;

	return raw_socket_table_is_inserted_rcu(sniff_table, hash, inode);
}

// Consumes the sniff entry of 'inode', if any.
// Returns true only for the caller that actually removed the entry: the cheap
// RCU lookup filters out untracked inodes, and the locked erase decides the
// winner when several callers race on the same inode.
static bool raw_recvmsg_sniff_needs_handling(const struct inode *inode)
{
	struct hlist_head* sniff_table = global_raw_socket_manager.recvmsg_sniff_hashtable;

	return raw_socket_table_is_inserted(sniff_table, inode)
	    && raw_socket_table_erase(sniff_table, inode);
}

// MARK: APIs
void raw_socket_manager_init(void)
{
	mutex_init(&global_raw_socket_manager.table_writer_lock);
	global_raw_socket_manager.active = false;
	hash_init(global_raw_socket_manager.recvmsg_sniff_hashtable);
}

void raw_socket_manager_activate(void)
{
	mutex_lock(&global_raw_socket_manager.table_writer_lock);
	global_raw_socket_manager.active = true;
	mutex_unlock(&global_raw_socket_manager.table_writer_lock);
}

void raw_socket_manager_deactivate(void)
{
	mutex_lock(&global_raw_socket_manager.table_writer_lock);
	if (global_raw_socket_manager.active) {
		raw_socket_table_free(global_raw_socket_manager.recvmsg_sniff_hashtable
		                    , (int) ARRAY_SIZE(global_raw_socket_manager.recvmsg_sniff_hashtable));
		global_raw_socket_manager.active = false;
	}
	mutex_unlock(&global_raw_socket_manager.table_writer_lock);
}

// Drops the sniff entry keyed by 'inode' so that the table never keeps a key
// for an inode that is going away (presumably called from the inode-free
// security hook - confirm against the caller).
// Inodes without a superblock, or whose superblock is not sockfs, are ignored.
void raw_socket_manager_inode_free_security(const struct inode *inode)
{
	unsigned int bucket;
	bool tracked;
	bool unlinked;
	hashtable_raw_socket_node_t *stale_node;

	if (!inode->i_sb || inode->i_sb->s_magic != SOCKFS_MAGIC)
		return;

	// The hash is only needed once the inode is known to be a socket inode.
	bucket = moremur_hash((uint64_t) inode, RAW_SOCKET_EVENTS_TABLE_SIZE_BITS);

	// Lock-free pre-check: avoids taking the mutex for untracked sockets.
	rcu_read_lock();
	tracked = raw_recvmsg_sniff_is_inserted_rcu(inode, bucket);
	rcu_read_unlock();
	if (!tracked)
		return;

	mutex_lock(&global_raw_socket_manager.table_writer_lock);
	unlinked = raw_socket_table_erase_impl(global_raw_socket_manager.recvmsg_sniff_hashtable, bucket, inode, &stale_node);
	mutex_unlock(&global_raw_socket_manager.table_writer_lock);

	// The entry may have been removed by someone else between the pre-check and the lock.
	if (unlinked)
		call_rcu(&stale_node->rcu, raw_socket_node_rcu_free);
}

// Registers the inode of 'sock' in the recvmsg sniff table.
// Nothing is registered when the combined transport mask does not contain the
// FP_SI_OT_NOTIFY_SOCKET_RECVMSG_RAW event, or when task_info_can_skip()
// reports that 'task_info' can be skipped for the interested transports.
void raw_socket_manager_will_sniff_recvmsg(task_info_t* task_info, struct socket *sock)
{
	const uint64_t raw_recvmsg_mask = MSG_TYPE_TO_EVENT_MASK(FP_SI_OT_NOTIFY_SOCKET_RECVMSG_RAW);
	transport_ids_t transport_ids;

	if ((transport_global_get_combined_mask() & raw_recvmsg_mask) == 0)
		return;

	transport_global_get_ids(&transport_ids, raw_recvmsg_mask);
	if (task_info_can_skip(task_info, &transport_ids, raw_recvmsg_mask))
		return;

	raw_socket_table_insert(global_raw_socket_manager.recvmsg_sniff_hashtable, SOCK_INODE(sock));
}

// 'recvmsg' notification for 'sock'.
// The sniff entry of the socket's inode is consumed first, so it is removed
// even when no event ends up being sent. An event is emitted through
// net_event_recvmsg_raw() only if an entry existed, the combined transport
// mask contains FP_SI_OT_NOTIFY_SOCKET_RECVMSG_RAW and the task cannot be skipped.
void raw_socket_manager_recvmsg(task_info_t* task_info, struct socket *sock)
{
	const uint64_t raw_recvmsg_mask = MSG_TYPE_TO_EVENT_MASK(FP_SI_OT_NOTIFY_SOCKET_RECVMSG_RAW);
	transport_ids_t transport_ids;
	struct inode *sock_inode = SOCK_INODE(sock);

	if (!raw_recvmsg_sniff_needs_handling(sock_inode))
		return;

	if ((transport_global_get_combined_mask() & raw_recvmsg_mask) == 0)
		return;

	transport_global_get_ids(&transport_ids, raw_recvmsg_mask);
	if (task_info_can_skip(task_info, &transport_ids, raw_recvmsg_mask))
		return;

	net_event_recvmsg_raw(task_info, sock);
}