/**
@file     net_events.c
@brief    Network based events
@details  Copyright (c) 2024 Acronis International GmbH
@author   Denis Kopyrin (denis.kopyrin@acronis.com)
@since    $Id: $
*/

#include "net_events.h"

#include "net_compat.h"
#include "transport/message.h"
#include "si_templates.h"
#include "si_writer.h"
#include "si_writer_common.h"

#include <asm/byteorder.h>
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/netdevice.h>
#include <linux/string.h>
#include <net/sock.h>
#include <net/ipv6.h>

#ifdef KERNEL_MOCK
#include "mock/mock.h"
#endif

/*
 * WRITE_SOCKADDR(name) defines
 *   static void si_property_writer_write_<name>_inet_sockaddr(writer, addr, addr_len)
 * which writes the address bytes and host-order port of an AF_INET or AF_INET6
 * address through si_property_writer_write_socket_<name>_address_ip() and
 * si_property_writer_write_socket_<name>_port().
 *
 * Nothing is written for a NULL address, for other families, or when addr_len
 * is too short for the family. sa_family itself is only read once addr_len
 * shows it is present: callers in this file pass sock_to_addr() results and
 * msg_namelen without validating them, so addr_len may be zero or negative and
 * the storage behind addr may not have been filled in.
 */
#define WRITE_SOCKADDR(name) \
static void si_property_writer_write_##name##_inet_sockaddr(si_property_writer_t *writer, const struct sockaddr *addr, int addr_len) { \
	SiSizedBuffer addr_buf; \
	if (!addr || addr_len < (int) sizeof(sa_family_t)) \
		return; \
	switch (addr->sa_family) { \
	case AF_INET: \
		if (addr_len >= (int) sizeof(struct sockaddr_in)) { \
			const struct sockaddr_in *addr4 = (const struct sockaddr_in *) addr; \
			addr_buf.value = (uint8_t *) &addr4->sin_addr; \
			addr_buf.length = sizeof(addr4->sin_addr); \
			si_property_writer_write_socket_##name##_address_ip(writer, addr_buf); \
			si_property_writer_write_socket_##name##_port(writer, ntohs(addr4->sin_port)); \
		} \
		break; \
	case AF_INET6: \
		if (addr_len >= SIN6_LEN_RFC2133) { \
			const struct sockaddr_in6 *addr6 = (const struct sockaddr_in6 *) addr; \
			addr_buf.value = (uint8_t *) &addr6->sin6_addr; \
			addr_buf.length = sizeof(addr6->sin6_addr); \
			si_property_writer_write_socket_##name##_address_ip(writer, addr_buf); \
			si_property_writer_write_socket_##name##_port(writer, ntohs(addr6->sin6_port)); \
		} \
		break; \
	default: \
		break; \
	} \
}

WRITE_SOCKADDR(local)
WRITE_SOCKADDR(remote)

/*
 * Write the target (device) name for an AF_PACKET address.
 *
 * SOCK_PACKET sockets carry the device name inline in sa_data; it is copied
 * into a bounded, NUL-terminated local buffer. All other packet socket types
 * carry a struct sockaddr_ll whose sll_ifindex is resolved to a device name
 * under RCU. Nothing is written when the address is too short, is not
 * AF_PACKET, has a non-positive ifindex, or the index does not resolve.
 */
static void si_property_writer_write_packet_socket_target(si_property_writer_t* writer, const struct socket* sock, const struct sockaddr* sa, int len)
{
	if (len <= (int) sizeof(sa_family_t) || AF_PACKET != sa->sa_family)
		return;

	if (SOCK_PACKET != sock->type)
	{
		// Explicitly ask for high enough kernel version. It does not matter for functionality
		// because officially only kernels >=3.10 are EDR supported.
		// Version check must be synchronized with 'DRIVER_FEATURE_NETWORK_PACKET_DEVICE_NAME' flag.
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2, 6, 33)
		// RAW, DGRAM, etc. sockets give us the device index; map it to the device name.
		const struct sockaddr_ll *ll_addr = (const struct sockaddr_ll*) sa;
		const int ifindex_end = (int) (offsetof(struct sockaddr_ll, sll_ifindex) + sizeof(int));
		struct net_device *dev;

		if (len < ifindex_end || ll_addr->sll_ifindex <= 0)
			return;

		rcu_read_lock();
		dev = dev_get_by_index_rcu(sock_net(sock->sk), ll_addr->sll_ifindex);
		if (dev)
		{
			SiSizedString dev_name = {
				.value = dev->name,
				.length = strnlen(dev->name, sizeof(dev->name)),
			};
			si_property_writer_write_target_name(writer, dev_name);
		}
		rcu_read_unlock();
#endif
		return;
	}

	// SOCK_PACKET gives the name of the device, but it is not trusted to be
	// terminated, so copy a bounded prefix and terminate it manually.
	{
		char name[IFNAMSIZ];
		int name_len = len - (int) sizeof(sa_family_t);
		SiSizedString target;

		if (name_len > IFNAMSIZ - 1)
			name_len = IFNAMSIZ - 1;

		memcpy(name, sa->sa_data, name_len);
		name[name_len] = '\0';

		target.value = name;
		target.length = strlen(name);
		si_property_writer_write_target_name(writer, target);
	}
}


/*
 * Upper bound on the variable-size bytes the address writers in this file may
 * emit for @addr, added on top of the fixed SI_ESTIMATE_TMPL_SIZE() of an event.
 *
 * Returns the IP address size for AF_INET / AF_INET6 addresses long enough to
 * be written, IFNAMSIZ for AF_PACKET (room for a device name), and 0 otherwise.
 * sa_family is only read once addr_len shows it is present: callers pass
 * sock_to_addr() results and msg_namelen unchecked, so addr_len may be zero
 * or negative and the storage behind @addr may not have been filled in.
 */
static int sockaddr_varsize(const struct sockaddr *addr, int addr_len) {
	if (!addr || addr_len < (int) sizeof(sa_family_t))
		return 0;

	switch (addr->sa_family) {
	case AF_INET:
		if (addr_len >= (int) sizeof(struct sockaddr_in))
			return (int) sizeof(struct in_addr);
		break;
	case AF_INET6:
		if (addr_len >= SIN6_LEN_RFC2133)
			return (int) sizeof(struct in6_addr);
		break;
	case AF_PACKET:
		// Over-estimate is fine: the packet writer emits at most IFNAMSIZ - 1 name bytes.
		return IFNAMSIZ;
	default:
		break;
	}

	return 0;
}

/*
 * Write the socket family (taken from the underlying sock) and the socket
 * type (sock->type, written as the "protocol" property), both truncated to 8 bits.
 */
static void si_property_writer_write_socket(si_property_writer_t *writer, const struct socket *sock) {
	const uint8_t family = (uint8_t) sock->sk->sk_family;
	const uint8_t type = (uint8_t) sock->type;

	si_property_writer_write_socket_family(writer, family);
	si_property_writer_write_socket_protocol(writer, type);
}

/* Properties of the bind event; passed to SI_ESTIMATE_TMPL_SIZE() to size it. */
#define SI_TMPL_SOCKET_BIND(tmpl) \
	SI_COMMON_FIELDS(tmpl) \
	tmpl(FP_SI_PI_SOCKET_FAMILY) \
	tmpl(FP_SI_PI_SOCKET_PROTOCOL) \
	tmpl(FP_SI_PI_SOCKET_LOCAL_PORT) \
	tmpl(FP_SI_PI_SOCKET_LOCAL_ADDRESS_IP) \
	tmpl(FP_SI_PI_TARGET_NAME)

/*
 * Emit a FP_SI_OT_NOTIFY_SOCKET_BIND pre-callback event for bind(@sock, @addr, @addr_len).
 *
 * AF_PACKET sockets (judged by sock->sk->sk_family) report the bound device
 * as the target name; every other family reports the local address and port
 * from @addr (written only for AF_INET / AF_INET6).
 * If msg_new() fails, msg is NULL and is still passed to send_msg_async_unref()
 * — presumably that accepts NULL; confirm.
 */
void net_event_bind(task_info_t* task_info, struct socket *sock, struct sockaddr *addr, int addr_len)
{
	uint64_t unique_pid = make_unique_pid(current);
	uint64_t event_uid;
	bool is_packet = (AF_PACKET == sock->sk->sk_family);
	// Fixed template size plus room for the address bytes / device name.
	const uint32_t event_size = SI_ESTIMATE_TMPL_SIZE(SI_TMPL_SOCKET_BIND) + sockaddr_varsize(addr, addr_len);
	msg_t *msg = msg_new(FP_SI_OT_NOTIFY_SOCKET_BIND
	                   , 0
	                   , SI_CT_PRE_CALLBACK
	                   , unique_pid
	                   , event_size);
	if (!msg)
		goto end;

	event_uid = transport_global_sequence_next();

	{
		si_property_writer_t writer;
		si_event_writer_init(&writer, &msg->event, event_size);
		si_property_writer_write_common(&writer, event_uid, current->pid, current->tgid, task_info);
		si_property_writer_write_socket(&writer, sock);
		if (is_packet)
		{
			si_property_writer_write_packet_socket_target(&writer, sock, addr, addr_len);
		}
		else
		{
			si_property_writer_write_local_inet_sockaddr(&writer, addr, addr_len);
		}
		si_event_writer_finalize(&msg->event, &writer);
	}

	msg->task_info = task_info_get(task_info);
	msg->id = event_uid;

end:
	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	send_msg_async_unref(msg);
}

/* Properties of address-carrying socket events; passed to SI_ESTIMATE_TMPL_SIZE(). */
#define SI_TMPL_SOCKET_CREATE(tmpl) \
	SI_COMMON_FIELDS(tmpl) \
	tmpl(FP_SI_PI_SOCKET_FAMILY) \
	tmpl(FP_SI_PI_SOCKET_PROTOCOL) \
	tmpl(FP_SI_PI_SOCKET_REMOTE_PORT) \
	tmpl(FP_SI_PI_SOCKET_REMOTE_ADDRESS_IP) \
	tmpl(FP_SI_PI_SOCKET_LOCAL_PORT) \
	tmpl(FP_SI_PI_SOCKET_LOCAL_ADDRESS_IP)

/*
 * Shared emitter for events that carry both a local and a remote address
 * (used by connect and accept below).
 *
 * @operation / @callback_type are passed straight to msg_new(). The socket
 * family/type are always written; each address is written only if it is a
 * sufficiently long AF_INET / AF_INET6 address.
 * If msg_new() fails, msg is NULL and is still passed to send_msg_async_unref()
 * — presumably that accepts NULL; confirm.
 */
static void net_event_send_generic(task_info_t* task_info
                                 , uint16_t operation
                                 , uint16_t callback_type
                                 , const struct socket *sock
                                 , const struct sockaddr *local_addr
                                 , int local_addr_len
                                 , const struct sockaddr *remote_addr
                                 , int remote_addr_len)
{
	uint64_t unique_pid = make_unique_pid(current);
	uint64_t event_uid;
	// Fixed template size plus room for both address payloads.
	const uint32_t event_size = SI_ESTIMATE_TMPL_SIZE(SI_TMPL_SOCKET_CREATE)
	                          + sockaddr_varsize(local_addr, local_addr_len)
	                          + sockaddr_varsize(remote_addr, remote_addr_len);
	msg_t *msg = msg_new(operation
	                   , 0
	                   , callback_type
	                   , unique_pid
	                   , event_size);
	if (!msg)
		goto end;

	event_uid = transport_global_sequence_next();

	{
		si_property_writer_t writer;
		si_event_writer_init(&writer, &msg->event, event_size);
		si_property_writer_write_common(&writer, event_uid, current->pid, current->tgid, task_info);
		si_property_writer_write_socket(&writer, sock);
		si_property_writer_write_local_inet_sockaddr(&writer, local_addr, local_addr_len);
		si_property_writer_write_remote_inet_sockaddr(&writer, remote_addr, remote_addr_len);
		si_event_writer_finalize(&msg->event, &writer);
	}

	msg->task_info = task_info_get(task_info);
	msg->id = event_uid;

end:
	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	send_msg_async_unref(msg);
}

/*
 * Emit a FP_SI_OT_NOTIFY_SOCKET_CONNECT pre-callback event: the destination
 * comes from the connect() arguments, the local address from the socket itself
 * (sock_to_addr() with PEER_LOCAL).
 */
void net_event_connect(task_info_t* task, struct socket *sock, struct sockaddr *remote_addr, int remote_addr_len)
{
	// TODO: Usually this local address is empty, do we even need it?
	struct sockaddr_storage local_addr;
	int local_addr_len = sock_to_addr(sock, &local_addr, PEER_LOCAL);

	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	net_event_send_generic(task, FP_SI_OT_NOTIFY_SOCKET_CONNECT, SI_CT_PRE_CALLBACK, sock
	                     , (struct sockaddr*) &local_addr, local_addr_len
	                     , remote_addr                   , remote_addr_len);
}

/* Properties of the raw-socket creation event; passed to SI_ESTIMATE_TMPL_SIZE(). */
#define SI_TMPL_SOCKET_CREATE_RAW(tmpl) \
	SI_COMMON_FIELDS(tmpl) \
	tmpl(FP_SI_PI_SOCKET_FAMILY) \
	tmpl(FP_SI_PI_SOCKET_PROTOCOL)

/*
 * Emit a FP_SI_OT_NOTIFY_SOCKET_CREATE_RAW pre-callback event.
 *
 * No struct socket exists yet, so @family and @protocol are taken from the
 * caller and written truncated to 8 bits.
 * If msg_new() fails, msg is NULL and is still passed to send_msg_async_unref()
 * — presumably that accepts NULL; confirm.
 */
void net_event_create_raw(task_info_t* task, int family, int protocol)
{
	uint64_t unique_pid = make_unique_pid(current);
	uint64_t event_uid;
	const uint32_t event_size = SI_ESTIMATE_TMPL_SIZE(SI_TMPL_SOCKET_CREATE_RAW);
	msg_t *msg = msg_new(FP_SI_OT_NOTIFY_SOCKET_CREATE_RAW
	                   , 0
	                   , SI_CT_PRE_CALLBACK
	                   , unique_pid
	                   , event_size);
	if (!msg)
		goto end;

	event_uid = transport_global_sequence_next();

	{
		si_property_writer_t writer;
		si_event_writer_init(&writer, &msg->event, event_size);
		si_property_writer_write_common(&writer, event_uid, current->pid, current->tgid, task);
		si_property_writer_write_socket_family(&writer, (uint8_t) family);
		si_property_writer_write_socket_protocol(&writer, (uint8_t) protocol);
		si_event_writer_finalize(&msg->event, &writer);
	}

	msg->task_info = task_info_get(task);
	msg->id = event_uid;

end:
	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	send_msg_async_unref(msg);
}

/*
 * Emit a FP_SI_OT_NOTIFY_SOCKET_ACCEPT post-callback event for the newly
 * accepted @sock. Both addresses are read back from the socket: the peer via
 * PEER_REMOTE_ALWAYS and the local end via PEER_LOCAL.
 */
void net_event_post_accept(task_info_t* task, struct socket *sock)
{
	struct sockaddr_storage remote_addr;
	int remote_addr_len = sock_to_addr(sock, &remote_addr, PEER_REMOTE_ALWAYS);
	struct sockaddr_storage local_addr;
	int local_addr_len = sock_to_addr(sock, &local_addr, PEER_LOCAL);

	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	net_event_send_generic(task, FP_SI_OT_NOTIFY_SOCKET_ACCEPT, SI_CT_POST_CALLBACK, sock
	                     , (struct sockaddr*) &local_addr , local_addr_len
	                     , (struct sockaddr*) &remote_addr, remote_addr_len);
}

/* Properties of the sendmsg events; passed to SI_ESTIMATE_TMPL_SIZE(). */
#define SI_TMPL_SOCKET_SENDMSG(tmpl) \
	SI_COMMON_FIELDS(tmpl) \
	tmpl(FP_SI_PI_SOCKET_FAMILY) \
	tmpl(FP_SI_PI_SOCKET_PROTOCOL) \
	tmpl(FP_SI_PI_SOCKET_REMOTE_PORT) \
	tmpl(FP_SI_PI_SOCKET_REMOTE_ADDRESS_IP) \
	tmpl(FP_SI_PI_SOCKET_LOCAL_PORT) \
	tmpl(FP_SI_PI_SOCKET_LOCAL_ADDRESS_IP)

/*
 * Emit a pre-callback sendmsg event of type @operation for @sock.
 *
 * The remote address is the explicit destination in @msghdr (msg_name /
 * msg_namelen) when one is given; otherwise (no msghdr, or no msg_name) it is
 * the socket's peer from sock_to_addr(PEER_REMOTE_ALWAYS). The local address
 * always comes from sock_to_addr(PEER_LOCAL).
 * If msg_new() fails, msg is NULL and is still passed to send_msg_async_unref()
 * — presumably that accepts NULL; confirm.
 */
static void net_event_sendmsg(task_info_t* task_info, struct socket *sock, struct msghdr *msghdr, uint16_t operation)
{
	uint64_t unique_pid = make_unique_pid(current);
	uint64_t event_uid;
	struct sockaddr_storage remote_storage_addr;
	int remote_addr_len;
	struct sockaddr_storage local_storage_addr;
	int local_addr_len = sock_to_addr(sock, &local_storage_addr, PEER_LOCAL);
	const struct sockaddr *remote_addr;
	const struct sockaddr *local_addr = (const struct sockaddr*) &local_storage_addr;
	uint32_t event_size;
	msg_t* msg;

	if (!msghdr || !msghdr->msg_name) {
		remote_addr_len = sock_to_addr(sock, &remote_storage_addr, PEER_REMOTE_ALWAYS);
		remote_addr = (const struct sockaddr*) &remote_storage_addr;
	} else {
		remote_addr_len = msghdr->msg_namelen;
		remote_addr = (const struct sockaddr*) msghdr->msg_name;
	}

	// Fixed template size plus room for both address payloads.
	event_size = SI_ESTIMATE_TMPL_SIZE(SI_TMPL_SOCKET_SENDMSG)
	           + sockaddr_varsize(remote_addr, remote_addr_len)
	           + sockaddr_varsize(local_addr , local_addr_len);
	msg = msg_new(operation
	            , 0
	            , SI_CT_PRE_CALLBACK
	            , unique_pid
	            , event_size);
	if (!msg)
		goto end;

	event_uid = transport_global_sequence_next();

	{
		si_property_writer_t writer;
		si_event_writer_init(&writer, &msg->event, event_size);
		si_property_writer_write_common(&writer, event_uid, current->pid, current->tgid, task_info);
		si_property_writer_write_socket(&writer, sock);
		si_property_writer_write_remote_inet_sockaddr(&writer, remote_addr, remote_addr_len);
		si_property_writer_write_local_inet_sockaddr(&writer, local_addr, local_addr_len);
		si_event_writer_finalize(&msg->event, &writer);
	}

	msg->task_info = task_info_get(task_info);
	msg->id = event_uid;

end:
	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	send_msg_async_unref(msg);
}

/* Emit FP_SI_OT_NOTIFY_SOCKET_SENDMSG_TCP; the destination is the connected peer of @sock. */
void net_event_sendmsg_tcp(task_info_t* task_info, struct socket *sock)
{
	net_event_sendmsg(task_info, sock, NULL /*!msghdr because stateful connection*/, FP_SI_OT_NOTIFY_SOCKET_SENDMSG_TCP);
}

/* Emit FP_SI_OT_NOTIFY_SOCKET_SENDMSG_UDP; an explicit destination in @msg takes precedence over the peer. */
void net_event_sendmsg_udp(task_info_t* task_info, struct socket *sock, struct msghdr *msg)
{
	net_event_sendmsg(task_info, sock, msg, FP_SI_OT_NOTIFY_SOCKET_SENDMSG_UDP);
}

/*
 * Write the target device name of an AF_PACKET socket, using the address the
 * socket is bound to. Nothing is written if that address cannot be obtained
 * or is too short to carry anything beyond sa_family.
 */
static void si_property_writer_write_packet_socket_target_sock(si_property_writer_t* writer, struct socket *sock)
{
	struct sockaddr_storage storage_addr;
	// Must be 0 - every other value will return -EOPNOTSUPP
	int addr_len = sock_to_addr(sock, &storage_addr, 0);
	if (addr_len <= (int) sizeof(sa_family_t))
		return;

	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	si_property_writer_write_packet_socket_target(writer, sock, (struct sockaddr*) &storage_addr, addr_len);
}

/* Properties of the raw recvmsg event; passed to SI_ESTIMATE_TMPL_SIZE(). */
#define SI_TMPL_SOCKET_RECVMSG_RAW(tmpl) \
	SI_COMMON_FIELDS(tmpl) \
	tmpl(FP_SI_PI_SOCKET_FAMILY) \
	tmpl(FP_SI_PI_SOCKET_PROTOCOL) \
	tmpl(FP_SI_PI_TARGET_NAME)

/*
 * Emit a FP_SI_OT_NOTIFY_SOCKET_RECVMSG_RAW pre-callback event for @sock.
 *
 * For AF_PACKET sockets, IFNAMSIZ extra bytes are reserved and the device
 * name derived from the socket's bound address is written as the target name.
 * If msg_new() fails, msg is NULL and is still passed to send_msg_async_unref()
 * — presumably that accepts NULL; confirm.
 */
void net_event_recvmsg_raw(task_info_t* task_info, struct socket *sock)
{
	uint64_t unique_pid = make_unique_pid(current);
	uint64_t event_uid;
	bool is_packet = (AF_PACKET == sock->sk->sk_family);
	uint32_t event_size = SI_ESTIMATE_TMPL_SIZE(SI_TMPL_SOCKET_RECVMSG_RAW) + (is_packet ? IFNAMSIZ : 0);
	msg_t* msg = msg_new(FP_SI_OT_NOTIFY_SOCKET_RECVMSG_RAW
	                   , 0
	                   , SI_CT_PRE_CALLBACK
	                   , unique_pid
	                   , event_size);
	if (!msg)
		goto end;

	event_uid = transport_global_sequence_next();

	{
		si_property_writer_t writer;
		si_event_writer_init(&writer, &msg->event, event_size);
		si_property_writer_write_common(&writer, event_uid, current->pid, current->tgid, task_info);
		si_property_writer_write_socket(&writer, sock);
		if (is_packet)
		{
			si_property_writer_write_packet_socket_target_sock(&writer, sock);
		}
		si_event_writer_finalize(&msg->event, &writer);
	}

	msg->task_info = task_info_get(task_info);
	msg->id = event_uid;

end:
	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	send_msg_async_unref(msg);
}

/* Properties of the auth-log events; passed to SI_ESTIMATE_TMPL_SIZE(). */
#define SI_TMPL_AUTH_LOG(tmpl) \
	SI_COMMON_FIELDS(tmpl) \
	tmpl(FP_SI_PI_LOG_STR)

/*
 * Emit a post-callback auth-log event carrying @str as the log string.
 *
 * @success selects FP_SI_OT_NOTIFY_SOCKET_AUTH_LOG_SUCCESS or
 * FP_SI_OT_NOTIFY_SOCKET_AUTH_LOG_FAILED. The event is sized for str.len bytes
 * of text; the string is written as a sized (not NUL-terminated) value.
 * If msg_new() fails, msg is NULL and is still passed to send_msg_async_unref()
 * — presumably that accepts NULL; confirm.
 */
void net_event_auth_log(task_info_t* task_info, string_view_t str, bool success)
{
	uint64_t unique_pid = make_unique_pid(current);
	uint64_t event_uid;
	const uint32_t event_size = SI_ESTIMATE_TMPL_SIZE(SI_TMPL_AUTH_LOG) + str.len;
	msg_t *msg = msg_new(success ? FP_SI_OT_NOTIFY_SOCKET_AUTH_LOG_SUCCESS : FP_SI_OT_NOTIFY_SOCKET_AUTH_LOG_FAILED
	                   , 0
	                   , SI_CT_POST_CALLBACK
	                   , unique_pid
	                   , event_size);
	if (!msg)
		goto end;

	event_uid = transport_global_sequence_next();

	{
		si_property_writer_t writer;
		si_event_writer_init(&writer, &msg->event, event_size);
		si_property_writer_write_common(&writer, event_uid, current->pid, current->tgid, task_info);
		{
			SiSizedString si_str;
			si_str.value = str.str;
			si_str.length = str.len;
			si_property_writer_write_log_str(&writer, si_str);
		}
		si_event_writer_finalize(&msg->event, &writer);
	}

	msg->task_info = task_info_get(task_info);
	msg->id = event_uid;

end:
	// Plain call: ISO C forbids 'return <expr>;' in a function returning void.
	send_msg_async_unref(msg);
}
