跟我一起学习drgn(4)---实现'ss -nlpt'
ss -nlpt
是查看本机tcp监听端口的常用命令, 而本文将使用drgn
重新实现它. 目的是为了证明一点: 作为一种快照式的调试工具, drgn
能非常有效地检视运行中内核的实时数据.
预期结果
先看看 ss -nlpt
会展示什么信息:
图中展示了5条记录, 每条记录包含了监听信息: TCP状态, 队列长度, 地址端口, 进程和fd信息.
数据来源
接下来, 我们需要从内核源码中寻找上述信息的来源.
本文使用的内核版本是 6.2.2, 不同内核版本的数据结构略有不同
内核创建的所有 tcp socket 都保存在 tcp_hashinfo, 而其中的 lhash2数组则保存了所有的监听套接字
(每条链表上的每个struct sock
就代表了一个监听端口)
我们需要的监听信息大多可以从struct sock
获得.
- TCP状态
可以从sk.__sk_common.skc_state
获得. 不过, drgn
提供了一个现成的 Helper 函数 sk_tcpstate
用于获取 TCP 状态.
def sk_tcpstate(sk: Object) -> Object:
"""
Return the TCP protocol state of a socket.
:param sk: ``struct sock *``
:return: TCP state enum value.
"""
return cast(sk.prog_["TCP_ESTABLISHED"].type_, sk.__sk_common.skc_state)
- 队列长度
接收队列长度从 sk.sk_receive_queue.qlen
获得.
发送队列长度从 sk.sk_max_ack_backlog
获得 (LISTEN 套接字)
- 地址端口
本端地址从 IPv4:sk.__sk_common.skc_rcv_saddr
获得
对端地址从 IPv4:sk.__sk_common.skc_daddr
获得
本端端口从 inet.inet_sport
获得
对端端口从sk__sk_common.skc_dport
获得. 其中inet
由sk
转换而来.
- 进程和fd信息
这个我找了半天, 只找到以下两条路径:
struct sock
–> struct socket
–> struct inode
struct sock
–> struct socket
–> strcut file
我原本以为从 strcut file
的 f_owner能拿到进程信息, 但结果并没有成功.
也罢, 只能另寻它路: 扫描 /proc/[PID]/fd
路径.
如果一个 fd 是 socket, 则该目录下的对应文件是一个链接, 同时会展示其 inode 编号
举个例子: PID 1103 的 fd 3:
lrwx------ 1 root root 64 12月 11 20:41 /proc/1103/fd/3 -> 'socket:[23223]'
因此我们可以提前扫描系统中所有进程的所有 fd, 将其中的 socket 对应的 inode 以及 PID 、fd 信息都收集起来.
之后再扫描 tcp_hashinfo
时, 查询套接字的 inode 编号, 从上面已收集信息中找到 PID、fd.
最终结果
script 代码贴在文末
可以看出, 它与 ss -nlpt
几乎一致!
附录 代码
最终代码如下:
#!/usr/bin/env drgn
import ipaddress
import socket
import struct
import os
from drgn import cast, container_of
from drgn.helpers.common.type import enum_type_to_class
from drgn.helpers.linux import (
SOCK_INODE,
hlist_nulls_empty,
hlist_nulls_for_each_entry,
sk_fullsock,
sk_nulls_for_each,
sk_tcpstate,
)
def is_socket_file(file_path):
try:
return os.readlink(file_path).startswith('socket:')
except OSError:
return False
def build_socket_files_dict():
socket_files_dict = {}
for pid in os.listdir('/proc'):
if pid.isdigit():
process_name = ""
try:
with open(f"/proc/{pid}/comm", "r") as file:
process_name = file.read().strip()
except IOError:
process_name = "N/A"
fd_path = os.path.join('/proc', pid, 'fd')
if os.path.exists(fd_path):
for fd_file in os.listdir(fd_path):
file_path = os.path.join(fd_path, fd_file)
if os.path.islink(file_path) and is_socket_file(file_path):
try:
inode = os.stat(file_path).st_ino
if inode not in socket_files_dict:
socket_files_dict[inode] = []
socket_files_dict[inode].append({"pid":pid, "name":process_name, "fd": fd_file})
except OSError:
pass
else:
print(f"{fd_path} not exist")
return socket_files_dict
def get_socket_info_by_inode(socket_info, inode_number):
if inode_number in socket_info:
return socket_info[inode_number]
else:
return None
TcpState = enum_type_to_class(
prog["TCP_ESTABLISHED"].type_,
"TcpState",
exclude=("TCP_MAX_STATES",),
prefix="TCP_",
)
def inet_sk(sk):
return cast("struct inet_sock *", sk)
def _ipv4(be32):
return ipaddress.IPv4Address(struct.pack("I", be32.value_()))
def _ipv6(in6_addr):
return ipaddress.IPv6Address(struct.pack("IIII", *in6_addr.in6_u.u6_addr32))
def _brackets(ip):
if ip.version == 4:
return "{}".format(ip.compressed)
elif ip.version == 6:
return "[{}]".format(ip.compressed)
return ""
def _ip_port(ip, port):
return "{:>24}:{:<10}".format(_brackets(ip), port)
def _process_info(socket_info):
return "\"{}\",pid={}fd={}".format(socket_info[0]["name"], socket_info[0]["pid"], socket_info[0]["fd"])
def _print_sk(sk, socket_dict):
inet = inet_sk(sk)
sock = sk.sk_socket
inode = SOCK_INODE(sock)
socket_info = get_socket_info_by_inode(socket_dict, inode.i_ino.value_())
if socket_info is None:
return
tcp_state = TcpState(sk_tcpstate(sk))
if sk.__sk_common.skc_family == socket.AF_INET:
src_ip = _ipv4(sk.__sk_common.skc_rcv_saddr)
dst_ip = _ipv4(sk.__sk_common.skc_daddr)
elif sk.__sk_common.skc_family == socket.AF_INET6:
src_ip = _ipv6(sk.__sk_common.skc_v6_rcv_saddr)
dst_ip = _ipv6(sk.__sk_common.skc_v6_daddr)
else:
return
recvq = sk.sk_receive_queue.qlen
sendq = sk.sk_max_ack_backlog
src_port = socket.ntohs(inet.inet_sport)
dst_port = socket.ntohs(sk.__sk_common.skc_dport)
print(
"{:<16}{:<16}{:<16}{}{}{}".format(
tcp_state.name,
recvq.value_(),
sendq.value_(),
_ip_port(src_ip, src_port),
_ip_port(dst_ip, dst_port),
_process_info(socket_info),
)
)
# collect socket
socket_dict = build_socket_files_dict()
tcp_hashinfo = prog.object("tcp_hashinfo")
print("State Recv-Q Send-Q Local Address:Port Peer Address:Port Process")
for i in range(tcp_hashinfo.lhash2_mask + 1):
head = tcp_hashinfo.lhash2[i].nulls_head
if hlist_nulls_empty(head):
continue
for sk in sk_nulls_for_each(head):
_print_sk(sk, socket_dict)