/***************************************************************************************************
 *    Copyright (c) 2025 Cisco Systems, Inc.
 *    All Rights Reserved. Cisco Highly Confidential.
 * 
 *    This program is free software; you can redistribute it and/or
 *    modify it under the terms of the GNU General Public License
 *    as published by the Free Software Foundation; either version 2
 *    of the License, or (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 * 
 ****************************************************************************************************
 *
 *    File:     interceptor.bpf.c
 *    Author:   shuchaud
 *    Date:     1/2025
 *
 ****************************************************************************************************
 *    Description: BPF kernel program to intercept packets and store them in a ring buffer.
 ****************************************************************************************************/

#ifdef NVM_BPF_USER
#undef NVM_BPF_USER
#endif

#ifndef NVM_BPF_KERNEL
#define NVM_BPF_KERNEL
#endif

#include "defines.h"
#include "nvm_user_kernel_types.h"
#include "dns_user_kernel_types.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include <bpf/bpf_core_read.h>

#define ETH_P_IP 0x0800   /* Internet Protocol packet	*/
#define ETH_P_IPV6 0x86DD /* IPv6 over bluebook		*/

#define TC_ACT_OK 0

#define DNS_PORT 53
#define MAX_VPN_IFACE_ENTRIES 16

enum log_level
{
    NVM_BPF_ERROR = 0,
    NVM_BPF_DEBUG
};

static bool verbose = false;

#define LOG(level, fmt, ...)                                          \
    do                                                                \
    {                                                                 \
        if ((level) == NVM_BPF_ERROR)                                 \
        {                                                             \
            bpf_printk("[NVM_BPF_LOGS] [ERROR] " fmt, ##__VA_ARGS__); \
        }                                                             \
        else if (verbose && NVM_BPF_DEBUG == (level))                 \
        {                                                             \
            bpf_printk("[NVM_BPF_LOGS] [DEBUG] " fmt, ##__VA_ARGS__); \
        }                                                             \
    } while (0)

struct
{
    __uint(type, BPF_MAP_TYPE_RINGBUF);
    __uint(max_entries, 1 << 24); // 16MB
} csc_ringbuf SEC(".maps");

// Map for sending data from userspace to kernel space
struct
{
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(max_entries, 1); // verbose mode
    __type(key, __u32);
    __type(value, __u32);
} csc_userspace_map SEC(".maps");

struct
{
    __uint(type, BPF_MAP_TYPE_HASH);
    __uint(max_entries, MAX_VPN_IFACE_ENTRIES); // max 16 VPN interfaces
    __type(key, __u32);      // VPN interface index
    __type(value, __u8);     // value (e.g., 1)
} csc_vpn_if_map SEC(".maps");

static __always_inline bool get_debug_mode(void)
{
    __u32 key = VERBOSE_KEY;
    __u32 *value = bpf_map_lookup_elem(&csc_userspace_map, &key);

    // Return the stored value (true/false) if found, otherwise default to false
    return value ? (*value > 0) : false;
}

// Function to get transport header (TCP/UDP)
static __always_inline error_code get_transport_header(void *transport_header, void *data_end, __u8 protocol, struct bpf_nw_pkt_meta *pkt, void **dns_payload)
{
    if (IPPROTO_TCP == protocol)
    {
        struct tcphdr *tcp = (struct tcphdr *)(transport_header);
        if ((void *)tcp + sizeof(struct tcphdr) > data_end)
        {
            LOG(NVM_BPF_ERROR, "TCP header exceeds data end");
            return ERROR_UNEXPECTED;
        }
        pkt->l4_header.h_tcp = *tcp;
    }
    else if (IPPROTO_UDP == protocol)
    {
        struct udphdr *udp = (struct udphdr *)(transport_header);
        if ((void *)udp + sizeof(struct udphdr) > data_end)
        {
            LOG(NVM_BPF_ERROR, "UDP header exceeds data end");
            return ERROR_UNEXPECTED;
        }
        pkt->l4_header.h_udp = *udp;

        // if DNS response, get the dns payload
        if ((DNS_PORT == bpf_ntohs(udp->source)) && INBOUND == pkt->direction)
        {
            if (bpf_ntohs(udp->len) > DNS_UDP_PACKET_MAX_SIZE)
            {
                LOG(NVM_BPF_ERROR, "DNS message payload(%d) exceeds max size(%d)", bpf_ntohs(udp->len), DNS_UDP_PACKET_MAX_SIZE);
                return ERROR_UNEXPECTED;
            }
            *dns_payload = (void *)(udp + 1);
        }
    }
    else
    {
        LOG(NVM_BPF_DEBUG, "Unsupported transport protocol 0x%x", protocol);
        return (error_code)ERROR_NOT_SUPPORTED;
    }
    return (error_code)ERROR_SUCCESS;
}

// Function to get IP header (IPv4/IPv6)
static __always_inline error_code get_ip_header(void *data, void *data_end, struct bpf_nw_pkt_meta *pkt, void **dns_payload)
{
    struct iphdr *ip = (struct iphdr *)(data);

    // Verify we can safely access the IP header
    if ((void *)ip + sizeof(struct iphdr) > data_end)
    {
        LOG(NVM_BPF_ERROR, "VPN IP header exceeds data end");
        return ERROR_UNEXPECTED;
    }

    // Check IP version
    if (ip->version == 4)
    {
        // IPv4 packet
        pkt->ip_header.h_ip4 = *ip;
        void *transport_header = (void *)ip + sizeof(struct iphdr);
        return get_transport_header(transport_header, data_end, ip->protocol, pkt, dns_payload);
    }
    else if (ip->version == 6)
    {
        // IPv6 packet
        struct ipv6hdr *ip6 = (struct ipv6hdr *)(data);
        if ((void *)ip6 + sizeof(struct ipv6hdr) > data_end)
        {
            LOG(NVM_BPF_ERROR, "VPN IPv6 header exceeds data end");
            return ERROR_UNEXPECTED;
        }
        pkt->ip_header.h_ip6 = *ip6;
        void *transport_header = (void *)ip6 + sizeof(struct ipv6hdr);
        return get_transport_header(transport_header, data_end, ip6->nexthdr, pkt, dns_payload);
    }
    else
    {
        LOG(NVM_BPF_DEBUG, "Invalid IP version: %d", ip->version);
        return (error_code)ERROR_NOT_SUPPORTED;
    }
}

static __always_inline error_code send_dns_data(struct bpf_nw_pkt_meta *pkt, void *dns_payload)
{
    size_t map_size = sizeof(struct bpf_nw_pkt_meta) + DNS_UDP_PACKET_MAX_SIZE;
    void *map_data = bpf_ringbuf_reserve(&csc_ringbuf, map_size, 0);
    if (!map_data)
    {
        LOG(NVM_BPF_ERROR, "Failed to reserve ring buffer");
        return ERROR_UNEXPECTED;
    }

    __builtin_memcpy(map_data, pkt, sizeof(struct bpf_nw_pkt_meta));
    void *payload_dst = map_data + sizeof(struct bpf_nw_pkt_meta);
    if (bpf_probe_read_kernel(payload_dst, DNS_UDP_PACKET_MAX_SIZE, dns_payload) != 0)
    {
        bpf_ringbuf_discard(map_data, 0);
        LOG(NVM_BPF_ERROR, "Failed to read DNS payload");
        return ERROR_UNEXPECTED;
    }
    bpf_ringbuf_submit(map_data, 0);
    return (error_code)ERROR_SUCCESS;
}

// Function to handle TC packets
int handle_tc_packet(struct __sk_buff *skb)
{
    verbose = get_debug_mode();

    struct bpf_nw_pkt_meta pkt = {};
    void *dns_payload = NULL;
    struct task_struct *task;
    error_code error;

    // Store needed skb fields in local variables right away
    __u32 ingress_ifindex = skb->ingress_ifindex;
    __u32 ifindex = skb->ifindex;
    void *data = (void *)(long)skb->data;
    void *data_end = (void *)(long)skb->data_end;

    // Determine direction using the stored values
    pkt.direction = ingress_ifindex ? INBOUND : OUTBOUND;

    // Use local variables for active interface calculation
    __u32 active_ifindex;
    if (pkt.direction == INBOUND)
    {
        active_ifindex = ingress_ifindex;
    }
    else
    {
        active_ifindex = ifindex;
    }

    // Check if this is a VPN interface
    __u8 *present = bpf_map_lookup_elem(&csc_vpn_if_map, &active_ifindex);
    if(present)
    {
        // This is a VPN interface - packets don't have Ethernet headers
        LOG(NVM_BPF_DEBUG, "Processing packet from VPN interface %u", active_ifindex);
        error = get_ip_header(data, data_end, &pkt, &dns_payload);
    }
    else
    {
        // Regular interface with Ethernet headers
        struct ethhdr *eth = data;
        if (data + sizeof(struct ethhdr) > data_end)
        {
            LOG(NVM_BPF_ERROR, "Ethernet header exceeds data end");
            return TC_ACT_OK;
        }
        error = get_ip_header(eth + 1, data_end, &pkt, &dns_payload);
    }

    if ((error_code)ERROR_SUCCESS != error && (error_code)ERROR_NOT_SUPPORTED != error)
    {
        LOG(NVM_BPF_ERROR, "Failed to get packet headers");
        return TC_ACT_OK;
    }

    task = (struct task_struct *)bpf_get_current_task();
    if (!task)
    {
        LOG(NVM_BPF_ERROR, "Failed to get current task");
        return TC_ACT_OK;
    }

    pkt.pid = BPF_CORE_READ(task, tgid);
    // Get the task_struct of the thread group leader
    struct task_struct *group_leader = BPF_CORE_READ(task, group_leader);
    if (!group_leader) {
        LOG(NVM_BPF_ERROR, "Failed to get thread group leader");
        return TC_ACT_OK;
    }

    // Read start_time from the thread group leader
    pkt.process_creation_time = BPF_CORE_READ(group_leader, start_time) / 1000000; // conversion from nanoseconds to milliseconds

    pkt.header.version = NVM_MESSAGE_APPFLOW_DATA;
    pkt.header.length = sizeof(struct bpf_nw_pkt_meta);

    if (NULL != dns_payload) // if dns response packet
    {
        pkt.header.length += DNS_UDP_PACKET_MAX_SIZE;
        error = send_dns_data(&pkt, dns_payload);
        if (ERROR_SUCCESS != error)
        {
            LOG(NVM_BPF_ERROR, "Failed to send DNS data");
        }
    }
    else
    {
        if (bpf_ringbuf_output(&csc_ringbuf, &pkt, sizeof(struct bpf_nw_pkt_meta), 0) != 0)
        {
            LOG(NVM_BPF_ERROR, "Failed to output to ring buffer");
        }
    }

    return TC_ACT_OK;
}

// TC ingress program
SEC("tc")
int csc_tc_prog_ingress(struct __sk_buff *skb)
{
    return handle_tc_packet(skb);
}

// TC egress program
SEC("tc")
int csc_tc_prog_egress(struct __sk_buff *skb)
{
    return handle_tc_packet(skb);
}

char _license[] SEC("license") = "GPL";
