/**
@file     tcp_socket_manager.c
@brief    Manager for TCP socket flow
@details  Copyright (c) 2024 Acronis International GmbH
@author   Denis Kopyrin (denis.kopyrin@acronis.com)
@since    $Id: $
*/

#include "tcp_socket_manager.h"

#include <linux/magic.h>
#include <linux/uio.h>
#include <net/sock.h>

#include "hashtable_compat.h"
#include "memory.h"
#include "net_compat.h"
#include "task_info_map.h"
#include "transport/transport.h"
#include "net_events.h"

#ifdef KERNEL_MOCK
#include "mock/mock.h"
#endif

// Fallback for builds where the included headers do not define SOCKFS_MAGIC.
// The bytes spell "SOCK" in ASCII. In this file the value is compared against
// a superblock's s_magic to tell socket inodes apart from other inodes.
#ifndef SOCKFS_MAGIC
#define SOCKFS_MAGIC		0x534F434B
#endif

// MARK: Hashtables
// Bit count handed to DECLARE_HASHTABLE() and moremur_hash() for both tables below.
#define TCP_SOCKET_EVENTS_TABLE_SIZE_BITS 8

// One tracked socket. Nodes are linked into a bucket with the RCU list helpers
// and released through call_rcu() using the embedded 'rcu' head.
typedef struct
{
	// linkage inside a hashtable bucket
	struct hlist_node hash_node;
	// used as a key; compared by pointer value only, never dereferenced via this field here
	const struct inode *inode;
	// deferred-free handle passed to call_rcu()
	struct rcu_head rcu;
} hashtable_tcp_socket_node_t;

// State of the TCP socket manager: two inode-keyed hashtables plus the lock
// and flag that guard modifications to them.
typedef struct tcp_socket_manager
{
	// Taken around every insertion into / removal from either hashtable, and
	// around changes of 'active'. Lookups use rcu_read_lock() instead.
	struct mutex table_writer_lock;
	// Insertions are accepted only while this is true; it is read and written under 'table_writer_lock'.
	bool active;
	// When 'accept' is called, the 'socket' is added to 'post_accept_hashtable' so that the
	// 'accept' event can be triggered later: LSM does not provide the address the 'socket'
	// was accepted on until after the LSM hook is done.
	DECLARE_HASHTABLE(post_accept_hashtable, TCP_SOCKET_EVENTS_TABLE_SIZE_BITS);
	// A socket is added to 'sendmsg_sniff_hashtable' after 'connect' so that its first
	// 'sendmsg' triggers the 'sendmsg' event. In the future, sockets that need to be sniffed
	// could be kept in this hashtable if sniffing of multiple 'sendmsg' calls is required.
	DECLARE_HASHTABLE(sendmsg_sniff_hashtable, TCP_SOCKET_EVENTS_TABLE_SIZE_BITS);
} tcp_socket_manager_t;

// The single manager instance used by every function in this file.
static tcp_socket_manager_t global_tcp_socket_manager;

// Callback handed to call_rcu(): maps the embedded rcu_head back to its
// enclosing table node and releases that node with mem_free().
static void tcp_socket_node_rcu_free(struct rcu_head *rcu)
{
	mem_free(container_of(rcu, hashtable_tcp_socket_node_t, rcu));
}

static void tcp_socket_table_free(struct hlist_head* heads, int size) {
	int idx;
	for (idx = 0; idx < size; idx++) {
		while (true) {
			hashtable_tcp_socket_node_t* node;
			struct hlist_node* first = heads[idx].first;
			if (!first)
				break;

			node = hlist_entry(first, hashtable_tcp_socket_node_t, hash_node);
			hlist_del_init_rcu(first);
			call_rcu(&node->rcu, tcp_socket_node_rcu_free);
		}
	}
}

static bool tcp_socket_table_is_inserted_rcu(const struct hlist_head* heads, int hash, const struct inode *inode) {
	hashtable_tcp_socket_node_t *search_node;
	hlist_for_each_entry_rcu(search_node, &heads[hash], hash_node) {
		if (search_node->inode == inode) {
			return true;
		}
	}
	return false;
}

// Unlinks the node keyed by 'inode' from bucket 'hash' of 'heads'.
//
// The node is only removed from the list, not freed: on success it is handed
// back through 'pnode' and the caller is expected to release it (callers in
// this file do so with call_rcu()). Both callers in this file hold
// 'table_writer_lock' around the call.
//
// Returns true and stores the unlinked node in '*pnode' when a match exists;
// returns false and stores NULL in '*pnode' otherwise.
static bool tcp_socket_table_erase_impl(struct hlist_head* heads, int hash, const struct inode *inode, hashtable_tcp_socket_node_t **pnode) {
	hashtable_tcp_socket_node_t *cursor;
	// Tracked separately from the loop cursor so that the value reported on a
	// miss does not depend on what the iterator macro leaves in its cursor.
	hashtable_tcp_socket_node_t *match = NULL;

	hlist_for_each_entry(cursor, &heads[hash], hash_node) {
		if (cursor->inode == inode) {
			match = cursor;
			break;
		}
	}

	if (match)
		hlist_del_init_rcu(&match->hash_node);

	*pnode = match;
	return match != NULL;
}

static bool tcp_socket_table_erase(struct hlist_head* heads, const struct inode *inode) {
	bool found = false;
	hashtable_tcp_socket_node_t *node;
	unsigned int hash = moremur_hash((uint64_t) inode, TCP_SOCKET_EVENTS_TABLE_SIZE_BITS);

	mutex_lock(&global_tcp_socket_manager.table_writer_lock);
	found = tcp_socket_table_erase_impl(heads, hash, inode, &node);
	mutex_unlock(&global_tcp_socket_manager.table_writer_lock);

	if (found) {
		call_rcu(&node->rcu, tcp_socket_node_rcu_free);
	}

	return found;
}

static void tcp_socket_table_insert(struct hlist_head* heads, const struct inode *inode)
{
	bool failed_to_insert = false;
	hashtable_tcp_socket_node_t *search_node;
	hashtable_tcp_socket_node_t *node;
	unsigned int hash = moremur_hash((uint64_t) inode, TCP_SOCKET_EVENTS_TABLE_SIZE_BITS);
	rcu_read_lock();
	if (tcp_socket_table_is_inserted_rcu(heads, hash, inode)) {
		rcu_read_unlock();
		return;
	}
	rcu_read_unlock();

	node = (hashtable_tcp_socket_node_t*) mem_alloc(sizeof(*node));
	if (!node) {
		return;
	}

	node->inode = inode;
	mutex_lock(&global_tcp_socket_manager.table_writer_lock);
	if (global_tcp_socket_manager.active) {
		hlist_for_each_entry(search_node, &heads[hash], hash_node) {
			if (search_node->inode == inode) {
				failed_to_insert = true;
				break;
			}
		}
		if (!failed_to_insert) {
			hlist_add_head_rcu(&node->hash_node, &heads[hash]);
		}
	} else {
		failed_to_insert = true;
	}
	mutex_unlock(&global_tcp_socket_manager.table_writer_lock);

	if (failed_to_insert) {
		mem_free(node);
	}
}

// Self-contained lookup: hashes 'inode', then searches the matching bucket of
// 'heads' inside an RCU read-side critical section.
// Returns true when the table holds a node for 'inode'.
static bool tcp_socket_table_is_inserted(const struct hlist_head* heads, const struct inode *inode) {
	bool present;
	unsigned int hash = moremur_hash((uint64_t) inode, TCP_SOCKET_EVENTS_TABLE_SIZE_BITS);

	rcu_read_lock();
	present = tcp_socket_table_is_inserted_rcu(heads, hash, inode);
	rcu_read_unlock();

	return present;
}

// Looks 'inode' up in bucket 'hash' of the post-accept table.
// The caller supplies the precomputed hash; the only caller in this file
// invokes it between rcu_read_lock() and rcu_read_unlock().
static bool tcp_post_accept_is_inserted_rcu(const struct inode *inode, unsigned int hash)
{
	const struct hlist_head *table = global_tcp_socket_manager.post_accept_hashtable;

	return tcp_socket_table_is_inserted_rcu(table, hash, inode);
}

// Looks 'inode' up in bucket 'hash' of the sendmsg-sniff table.
// The caller supplies the precomputed hash; the only caller in this file
// invokes it between rcu_read_lock() and rcu_read_unlock().
static bool tcp_sendmsg_sniff_is_inserted_rcu(const struct inode *inode, unsigned int hash)
{
	const struct hlist_head *table = global_tcp_socket_manager.sendmsg_sniff_hashtable;

	return tcp_socket_table_is_inserted_rcu(table, hash, inode);
}

// Consumes the post-accept entry for 'inode', if any.
// A lock-free lookup runs first so that sockets that are not tracked never
// touch the writer lock; only on a hit is the entry erased.
// Returns true only when this call actually removed the entry.
static bool tcp_post_accept_needs_handling(const struct inode *inode)
{
	struct hlist_head *table = global_tcp_socket_manager.post_accept_hashtable;

	return tcp_socket_table_is_inserted(table, inode)
	    && tcp_socket_table_erase(table, inode);
}

// Consumes the sendmsg-sniff entry for 'inode', if any.
// A lock-free lookup runs first so that sockets that are not tracked never
// touch the writer lock; only on a hit is the entry erased.
// Returns true only when this call actually removed the entry.
static bool tcp_sendmsg_sniff_needs_handling(const struct inode *inode)
{
	struct hlist_head *table = global_tcp_socket_manager.sendmsg_sniff_hashtable;

	return tcp_socket_table_is_inserted(table, inode)
	    && tcp_socket_table_erase(table, inode);
}

// MARK: APIs
void tcp_socket_manager_init(void)
{
	mutex_init(&global_tcp_socket_manager.table_writer_lock);
	global_tcp_socket_manager.active = false;
	hash_init(global_tcp_socket_manager.post_accept_hashtable);
	hash_init(global_tcp_socket_manager.sendmsg_sniff_hashtable);
}

void tcp_socket_manager_activate(void)
{
	mutex_lock(&global_tcp_socket_manager.table_writer_lock);
	global_tcp_socket_manager.active = true;
	mutex_unlock(&global_tcp_socket_manager.table_writer_lock);
}

void tcp_socket_manager_deactivate(void)
{
	mutex_lock(&global_tcp_socket_manager.table_writer_lock);
	if (global_tcp_socket_manager.active) {
		tcp_socket_table_free(global_tcp_socket_manager.post_accept_hashtable
		                    , (int) ARRAY_SIZE(global_tcp_socket_manager.post_accept_hashtable));
		tcp_socket_table_free(global_tcp_socket_manager.sendmsg_sniff_hashtable
		                    , (int) ARRAY_SIZE(global_tcp_socket_manager.sendmsg_sniff_hashtable));
		global_tcp_socket_manager.active = false;
	}
	mutex_unlock(&global_tcp_socket_manager.table_writer_lock);
}

// Drops any entries still keyed by 'inode' from both tables.
//
// Inodes without a superblock, or whose superblock magic is not SOCKFS_MAGIC,
// are ignored. For socket inodes a lock-free lookup in both tables runs
// first; the writer lock is taken only when at least one table lists the
// inode. Unlinked nodes are released through call_rcu() after the lock has
// been dropped.
void tcp_socket_manager_inode_free_security(const struct inode *inode)
{
	unsigned int hash = moremur_hash((uint64_t) inode, TCP_SOCKET_EVENTS_TABLE_SIZE_BITS);
	hashtable_tcp_socket_node_t *accept_entry;
	hashtable_tcp_socket_node_t *sniff_entry;
	bool accept_listed;
	bool sniff_listed;
	bool accept_unlinked;
	bool sniff_unlinked;

	if (!inode->i_sb || inode->i_sb->s_magic != SOCKFS_MAGIC)
		return;

	// Both tables share the bucket count, so one hash value serves both lookups.
	rcu_read_lock();
	accept_listed = tcp_post_accept_is_inserted_rcu(inode, hash);
	sniff_listed = tcp_sendmsg_sniff_is_inserted_rcu(inode, hash);
	rcu_read_unlock();

	if (!accept_listed && !sniff_listed)
		return;

	mutex_lock(&global_tcp_socket_manager.table_writer_lock);
	accept_unlinked = tcp_socket_table_erase_impl(global_tcp_socket_manager.post_accept_hashtable,
	                                              hash, inode, &accept_entry);
	sniff_unlinked = tcp_socket_table_erase_impl(global_tcp_socket_manager.sendmsg_sniff_hashtable,
	                                             hash, inode, &sniff_entry);
	mutex_unlock(&global_tcp_socket_manager.table_writer_lock);

	if (accept_unlinked)
		call_rcu(&accept_entry->rcu, tcp_socket_node_rcu_free);
	if (sniff_unlinked)
		call_rcu(&sniff_entry->rcu, tcp_socket_node_rcu_free);
}

// Registers 'sock' in the post-accept table so that a later call to
// tcp_socket_manager_check_post_accept() can emit the accept event for it.
//
// Nothing is registered when the combined transport mask does not contain the
// FP_SI_OT_NOTIFY_SOCKET_ACCEPT event bit, or when task_info_can_skip()
// returns true for 'task_info' with that event mask.
//
// @param task_info  task on whose behalf the socket is being accepted
// @param sock       socket whose inode becomes the table key
void tcp_socket_manager_will_post_accept(task_info_t* task_info, struct socket *sock)
{
	transport_ids_t transport_ids;
	const uint64_t generatedEventsMask = MSG_TYPE_TO_EVENT_MASK(FP_SI_OT_NOTIFY_SOCKET_ACCEPT);

	if (!(transport_global_get_combined_mask() & generatedEventsMask))
		return;

	transport_global_get_ids(&transport_ids, generatedEventsMask);
	if (task_info_can_skip(task_info, &transport_ids, generatedEventsMask))
		return;

	// Plain call statement: tcp_socket_table_insert() returns void, and ISO C
	// does not allow 'return <expression>;' inside a function returning void.
	tcp_socket_table_insert(global_tcp_socket_manager.post_accept_hashtable, SOCK_INODE(sock));
}

// Registers 'sock' in the sendmsg-sniff table so that a later call to
// tcp_socket_manager_sendmsg() can emit the TCP sendmsg event for it.
// Nothing is registered when the combined transport mask lacks the
// FP_SI_OT_NOTIFY_SOCKET_SENDMSG_TCP event bit, or when task_info_can_skip()
// returns true for 'task_info'.
void tcp_socket_manager_will_sniff_sendmsg(task_info_t* task_info, struct socket *sock)
{
	const uint64_t event_mask = MSG_TYPE_TO_EVENT_MASK(FP_SI_OT_NOTIFY_SOCKET_SENDMSG_TCP);
	transport_ids_t ids;

	if (!(transport_global_get_combined_mask() & event_mask))
		return;

	transport_global_get_ids(&ids, event_mask);
	if (task_info_can_skip(task_info, &ids, event_mask))
		return;

	tcp_socket_table_insert(global_tcp_socket_manager.sendmsg_sniff_hashtable, SOCK_INODE(sock));
}

// Emits the accept event for 'sock' if it was registered in the post-accept table.
// The table entry is consumed first, so it is removed even when the event ends
// up not being sent because the event bit is no longer in the combined
// transport mask or task_info_can_skip() returns true.
void tcp_socket_manager_check_post_accept(task_info_t* task_info, struct socket *sock)
{
	const uint64_t event_mask = MSG_TYPE_TO_EVENT_MASK(FP_SI_OT_NOTIFY_SOCKET_ACCEPT);
	transport_ids_t ids;

	if (!tcp_post_accept_needs_handling(SOCK_INODE(sock)))
		return;

	if (!(transport_global_get_combined_mask() & event_mask))
		return;

	transport_global_get_ids(&ids, event_mask);
	if (task_info_can_skip(task_info, &ids, event_mask))
		return;

	net_event_post_accept(task_info, sock);
}

// Emits the TCP sendmsg event for 'sock' if it was registered in the sendmsg-sniff table.
// The table entry is consumed first, so it is removed even when the event ends
// up not being sent because the event bit is no longer in the combined
// transport mask or task_info_can_skip() returns true.
// 'msg' and 'size' are accepted but not used.
void tcp_socket_manager_sendmsg(task_info_t* task_info, struct socket *sock, struct msghdr *msg, int size)
{
	const uint64_t event_mask = MSG_TYPE_TO_EVENT_MASK(FP_SI_OT_NOTIFY_SOCKET_SENDMSG_TCP);
	transport_ids_t ids;

	(void) msg;
	(void) size;

	if (!tcp_sendmsg_sniff_needs_handling(SOCK_INODE(sock)))
		return;

	if (!(transport_global_get_combined_mask() & event_mask))
		return;

	transport_global_get_ids(&ids, event_mask);
	if (task_info_can_skip(task_info, &ids, event_mask))
		return;

	net_event_sendmsg_tcp(task_info, sock);
}
