I am working on a school project. I set some rules in iptables that log INPUT and OUTPUT connections. My goal is to read these logs line by line, parse them, and find out which process (with which PID) caused each entry.

My problem starts when I use psutil to match an (ip, port) tuple to the corresponding PID. iptables writes log lines very fast, about one every 1x10^-6 seconds, and my Python script reads lines just as fast as iptables writes them. But when I use the following code:
def get_proc(src: str, spt: str, dst: str, dpt: str) -> str:
    proc_info = ""
    if not (src and spt and dst and dpt):
        return proc_info
    for proc in psutil.process_iter(["pid", "name"]):
        for conn in proc.connections(kind="all"):
            if flag.is_set():
                return proc_info
            if not all([
                hasattr(conn.laddr, "ip"), hasattr(conn.laddr, "port"),
                hasattr(conn.raddr, "ip"), hasattr(conn.raddr, "port"),
            ]):
                continue
            if not all([
                conn.laddr.ip == src, conn.laddr.port == int(spt),
                conn.raddr.ip == dst, conn.raddr.port == int(dpt),
            ]):
                continue
            return f"pid={proc.pid},name={proc.name()}"
    return proc_info
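(For reference, the scan time can be measured with a quick sketch like this, looping over processes the same way get_proc does:)

import time
import psutil

start = time.perf_counter()
# One full pass over all processes and their connections, as get_proc does.
for proc in psutil.process_iter(["pid", "name"]):
    try:
        proc.connections(kind="all")
    except (psutil.AccessDenied, psutil.NoSuchProcess):
        pass
print(f"one scan took {time.perf_counter() - start:.6f}s")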
psutil finishes its job in about 1x10^-3 seconds, i.e. 10^3 times slower than the reading process. What happens is that while one get_proc call runs, about 1000 new lines are read. This slowness quickly becomes a problem by the time 1x10^6 lines have been read, because in order to find the PID, I need to run this method immediately when the log line is received.
I thought of using multithreading, but as far as I understand it won't solve my problem, because the same latency problem remains. I haven't done much coding so far because I still can't find an algorithm to use; that's why there is no more code here.

How can I solve this problem, with or without multithreading? I can't speed up the execution of psutil itself, so I believe there must be a better approach.
Edit
The code that reads logs from iptables.log:
import re
import signal
import sys
import threading

CURSOR_POSITION = 0
flag = threading.Event()

def stop(signum, _frame):
    """
    Tell everything to stop themselves.

    :param signum: The captured signal number.
    :param _frame: No use.
    """
    if flag.is_set():
        return
    sys.stderr.write(f"Signal {signum} received.")
    flag.set()

signal.signal(signal.SIGINT, stop)

def receive_logs(file, queue__):
    global CURSOR_POSITION
    with open(file, encoding="utf-8") as _f:
        _f.seek(CURSOR_POSITION)
        while not flag.is_set():
            line = re.sub(r"[\[\]]", "", _f.readline().rstrip())
            if not line:
                continue
            # If all goes okay do some parsing...
            # .
            # .
            queue__.put_nowait((nettup, additional_info))
            CURSOR_POSITION = _f.tell()
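The parsing itself is elided above. For illustration only (this helper is hypothetical, not part of my actual script), the fields could be pulled out of a typical iptables LOG line, which carries SRC=, DST=, SPT= and DPT= key-value pairs:

import re

# Hypothetical parser: a line logged by "iptables -j LOG" looks like
# "... SRC=10.0.0.2 DST=93.184.216.34 ... SPT=51234 DPT=80 ..."
LOG_FIELDS = re.compile(r"\b(SRC|DST|SPT|DPT)=(\S+)")

def parse_line(line: str):
    fields = dict(LOG_FIELDS.findall(line))
    if {"SRC", "SPT", "DST", "DPT"} <= fields.keys():
        return (fields["SRC"], fields["SPT"], fields["DST"], fields["DPT"])
    return None  # no ports in this line (e.g. an ICMP packet)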
Answer:
Here is an approach that may help a bit. As I've mentioned in comments, the issue cannot be avoided entirely unless you switch to a different approach altogether.

The idea here is to scan the list of processes not once per connection but once for all connections that have arrived since the last scan. Since checking each connection is then a simple O(1) hash-table lookup, we can process messages much faster.
I chose to go with a simple 1-producer-1-consumer multithreading approach. I think this will work fine because most time is spent in system calls, so Python's global interpreter lock (GIL) is less of an issue. But that requires testing. Possible variations:

- Use no multithreading; instead read incoming logs non-blocking, then process what you've got
- Swap the threading module and queue for the multiprocessing module (see the sketch after this list)
- Use multiple consumer threads, perhaps with a maximum batch size per thread, so that multiple scans through the process list run in parallel
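For the second variation, the swap is mostly mechanical. A minimal sketch, assuming the receive_logs and consume functions keep the signatures used below:

import multiprocessing

def main():
    # multiprocessing.Queue replaces queue.SimpleQueue; Process replaces
    # Thread, so the psutil scans run in a separate interpreter and do not
    # share a GIL with the log reader at all.
    q = multiprocessing.Queue()
    producer = multiprocessing.Process(target=receive_logs, args=(q,))
    consumer = multiprocessing.Process(target=consume, args=(q,))
    producer.start()
    consumer.start()
    producer.join()
    consumer.join()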
import psutil
import queue
import threading

def receive_logs(consumer_queue):
    """Placeholder for actual code reading iptables log"""
    for connection in log:
        nettup = (connection.src, int(connection.spt),
                  connection.dst, int(connection.dpt))
        additional_info = connection.additional_info
        consumer_queue.put((nettup, additional_info))
The log reading is not part of the posted code, so this is just some placeholder.
Now we consume all queued connections in a second thread:
def get_procs(producer_queue):
    # 1. Construct a set of connections to search for
    # Blocks until at least one available
    nettup, additional_info = producer_queue.get()
    connections = {nettup: additional_info}
    try:  # read as many as possible
        while True:
            nettup, additional_info = producer_queue.get_nowait()
            connections[nettup] = additional_info
    except queue.Empty:
        pass
    # 2. Scan the process list once for all collected connections
    found = []
    for proc in psutil.process_iter(["pid", "name"]):
        for conn in proc.connections(kind="all"):
            try:
                src = conn.laddr.ip
                spt = conn.laddr.port
                dst = conn.raddr.ip
                dpt = conn.raddr.port
            except AttributeError:  # not an IP address
                continue
            nettup = (src, spt, dst, dpt)
            if nettup in connections:
                additional_info = connections[nettup]
                found.append((proc, nettup, additional_info))
    found_connections = {nettup for _, nettup, _ in found}
    lost = [(nettup, additional_info)
            for nettup, additional_info in connections.items()
            if nettup not in found_connections]
    return found, lost
I don't really understand parts of the posted code in the question, such as the if flag.is_set(): return proc_info part, so I just left those out. Also, I got rid of some of the less Pythonic and potentially slow parts, such as hasattr(). Adapt as needed.
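(The reason the try/except works: psutil reports raddr as an empty tuple for sockets without a remote endpoint, such as listening sockets, and attribute access on an empty tuple raises AttributeError. A tiny illustration:)

raddr = ()  # what psutil returns for e.g. a listening socket
try:
    raddr.ip
except AttributeError:
    print("no remote endpoint; skip this connection")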
Now we tie it all together by calling the consumer repeatedly and starting both threads:
def consume(producer_queue):
    while True:
        found, lost = get_procs(producer_queue)
        for proc, (src, spt, dst, dpt), additional_info in found:
            print(f"pid={proc.pid},name={proc.name()}")

def main():
    producer_consumer_queue = queue.SimpleQueue()
    producer = threading.Thread(
        target=receive_logs, args=(producer_consumer_queue,))
    consumer = threading.Thread(
        target=consume, args=(producer_consumer_queue,))
    consumer.start()
    producer.start()
    consumer.join()
    producer.join()

if __name__ == "__main__":
    main()
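To check the consumer end to end without a real log file, here is a hypothetical smoke test (smoke_test and the host/port are made up for illustration): it opens a TCP connection from this very process, enqueues its connection tuple, and verifies that get_procs finds us:

import queue
import socket

def smoke_test():
    q = queue.SimpleQueue()
    # Keep the connection open during the scan so it shows up in psutil.
    with socket.create_connection(("example.com", 80)) as s:
        src, spt = s.getsockname()[:2]
        dst, dpt = s.getpeername()[:2]
        q.put(((src, spt, dst, dpt), "smoke-test"))
        found, lost = get_procs(q)
        for proc, nettup, info in found:
            print(f"found: pid={proc.pid}, name={proc.name()}, info={info}")
        print(f"unmatched: {lost}")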