Switch-Router

跟我一起学习drgn(4)---实现'ss -nlpt'

Published at 2023-12-12 | Last Update 2023-12-12

ss -nlpt 是查看本机tcp监听端口的常用命令, 而本文将使用drgn重新实现它. 目的是为了证明一点: 作为一种快照式的调试工具, drgn 能非常有效地检视运行中内核的实时数据.

预期结果

先看看 ss -nlpt 会展示什么信息:

图中展示了5条记录, 每条记录包含了监听信息: TCP状态, 队列长度, 地址端口, 进程和fd信息.

数据来源

接下来, 我们需要从内核源码中寻找上述信息的来源.

本文使用的内核版本是 6.2.2, 不同内核版本的数据结构略有不同

内核创建的所有 tcp socket 都保存在 tcp_hashinfo, 而其中的 lhash2数组则保存了所有的监听套接字

(每条链表上的每个struct sock就代表了一个监听端口)

我们需要的监听信息大多可以从struct sock获得.

  • TCP状态

可以从sk.__sk_common.skc_state获得. 不过, drgn 提供了一个现成的 Helper 函数 sk_tcpstate 用于获取 TCP 状态.

def sk_tcpstate(sk: Object) -> Object:
    """
    Return the TCP protocol state of a socket.

    :param sk: ``struct sock *``
    :return: TCP state enum value.
    """
    return cast(sk.prog_["TCP_ESTABLISHED"].type_, sk.__sk_common.skc_state)
  • 队列长度

接收队列长度从 sk.sk_receive_queue.qlen 获得. 发送队列长度从 sk.sk_max_ack_backlog 获得 (LISTEN 套接字)

  • 地址端口

本端地址从 IPv4:sk.__sk_common.skc_rcv_saddr获得 对端地址从 IPv4:sk.__sk_common.skc_daddr获得

本端端口从 inet.inet_sport 获得 对端端口从sk__sk_common.skc_dport获得. 其中inetsk转换而来.

  • 进程和fd信息

这个我找了半天, 只找到以下两条路径:

struct sock –> struct socket –> struct inode struct sock –> struct socket –> strcut file

我原本以为从 strcut filef_owner能拿到进程信息, 但结果并没有成功.

也罢, 只能另寻它路: 扫描 /proc/[PID]/fd 路径.

如果一个 fd 是 socket, 则该目录下的对应文件是一个链接, 同时会展示其 inode 编号

举个例子: PID 1103 的 fd 3:

lrwx------ 1 root root 64 12月 11 20:41 /proc/1103/fd/3 -> 'socket:[23223]'

因此我们可以提前扫描系统中所有进程的所有 fd, 将其中的 socket 对应的 inode 以及 PID 、fd 信息都收集起来.

之后再扫描 tcp_hashinfo 时, 查询套接字的 inode 编号, 从上面已收集信息中找到 PID、fd.

最终结果

script 代码贴在文末

可以看出, 它与 ss -nlpt 几乎一致!

附录 代码

最终代码如下:

#!/usr/bin/env drgn

import ipaddress
import socket
import struct
import os

from drgn import cast, container_of
from drgn.helpers.common.type import enum_type_to_class
from drgn.helpers.linux import (
    SOCK_INODE,
    hlist_nulls_empty,
    hlist_nulls_for_each_entry,
    sk_fullsock,
    sk_nulls_for_each,
    sk_tcpstate,
)


def is_socket_file(file_path):
    try:
        return os.readlink(file_path).startswith('socket:')
    except OSError:
        return False

def build_socket_files_dict():
    socket_files_dict = {}
    for pid in os.listdir('/proc'):
        if pid.isdigit():
            process_name = ""
            try:
                with open(f"/proc/{pid}/comm", "r") as file:
                    process_name = file.read().strip()
            except IOError:
                    process_name = "N/A"

            fd_path = os.path.join('/proc', pid, 'fd')
            if os.path.exists(fd_path):
                for fd_file in os.listdir(fd_path):
                    file_path = os.path.join(fd_path, fd_file)
                    if os.path.islink(file_path) and is_socket_file(file_path):
                        try:
                            inode = os.stat(file_path).st_ino
                            if inode not in socket_files_dict:
                                socket_files_dict[inode] = []
                            socket_files_dict[inode].append({"pid":pid, "name":process_name, "fd": fd_file})
                        except OSError:
                            pass
            else:
                print(f"{fd_path} not exist")
    return socket_files_dict


def get_socket_info_by_inode(socket_info, inode_number):
    if inode_number in socket_info:
        return socket_info[inode_number]
    else:
        return None

TcpState = enum_type_to_class(
    prog["TCP_ESTABLISHED"].type_,
    "TcpState",
    exclude=("TCP_MAX_STATES",),
    prefix="TCP_",
)


def inet_sk(sk):
    return cast("struct inet_sock *", sk)


def _ipv4(be32):
    return ipaddress.IPv4Address(struct.pack("I", be32.value_()))


def _ipv6(in6_addr):
    return ipaddress.IPv6Address(struct.pack("IIII", *in6_addr.in6_u.u6_addr32))


def _brackets(ip):
    if ip.version == 4:
        return "{}".format(ip.compressed)
    elif ip.version == 6:
        return "[{}]".format(ip.compressed)
    return ""

def _ip_port(ip, port):
    return "{:>24}:{:<10}".format(_brackets(ip), port)

def _process_info(socket_info):
    return "\"{}\",pid={}fd={}".format(socket_info[0]["name"], socket_info[0]["pid"], socket_info[0]["fd"])

def _print_sk(sk, socket_dict):
    inet = inet_sk(sk)
    sock = sk.sk_socket
    inode = SOCK_INODE(sock)
   
    socket_info = get_socket_info_by_inode(socket_dict, inode.i_ino.value_())
    if socket_info is None:
        return

    tcp_state = TcpState(sk_tcpstate(sk))
    if sk.__sk_common.skc_family == socket.AF_INET:
        src_ip = _ipv4(sk.__sk_common.skc_rcv_saddr)
        dst_ip = _ipv4(sk.__sk_common.skc_daddr)
    elif sk.__sk_common.skc_family == socket.AF_INET6:
        src_ip = _ipv6(sk.__sk_common.skc_v6_rcv_saddr)
        dst_ip = _ipv6(sk.__sk_common.skc_v6_daddr)
    else:
        return

    recvq = sk.sk_receive_queue.qlen
    sendq = sk.sk_max_ack_backlog
    src_port = socket.ntohs(inet.inet_sport)
    dst_port = socket.ntohs(sk.__sk_common.skc_dport)

    print(
            "{:<16}{:<16}{:<16}{}{}{}".format(
            tcp_state.name,
            recvq.value_(),
            sendq.value_(),
            _ip_port(src_ip, src_port),
            _ip_port(dst_ip, dst_port),
            _process_info(socket_info),
        )
    )

# collect socket 
socket_dict = build_socket_files_dict()

tcp_hashinfo = prog.object("tcp_hashinfo")

print("State           Recv-Q          Send-Q                   Local Address:Port                   Peer Address:Port          Process")
for i in range(tcp_hashinfo.lhash2_mask + 1):
    head = tcp_hashinfo.lhash2[i].nulls_head
    if hlist_nulls_empty(head):
        continue
    for sk in sk_nulls_for_each(head):
        _print_sk(sk, socket_dict)