# SPDX-License-Identifier: GPL-2.0-only
# This file is part of Scapy
# See https://scapy.net/ for more information
# Copyright (C) Nils Weiss

# scapy.contrib.description = TestSocket library for unit tests
# scapy.contrib.status = library

import time
import random

from threading import Lock

from scapy.config import conf
from scapy.automaton import ObjectPipe, select_objects
from scapy.data import MTU
from scapy.packet import Packet
from scapy.error import Scapy_Exception

# Typing imports
from typing import (
    Optional,
    Type,
    Tuple,
    Any,
    List,
)
from scapy.supersocket import SuperSocket

from scapy.plist import (
    PacketList,
    SndRcvList,
)

# Registry of TestSocket instances that have not been closed yet.
# Sockets add themselves in __init__ and remove themselves in close();
# cleanup_testsockets() closes whatever is still registered.
open_test_sockets = list()  # type: List[TestSocket]


class TestSocket(SuperSocket):
    """In-memory socket for unit tests.

    ``send`` serializes a packet to bytes and pushes those bytes into the
    ``ins`` object pipe of every paired socket. ``recv_raw`` pops bytes
    from this socket's own ``ins`` pipe and returns them together with
    ``basecls``, so the regular SuperSocket receive path can dissect them.
    """

    test_socket_mutex = Lock()

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        """Create a test socket and register it in ``open_test_sockets``.

        :param basecls: packet class returned by ``recv_raw`` for dissection
        :param external_obj_pipe: optional caller-owned ObjectPipe to use as
            the receive queue; it is not closed by ``close()``
        """
        self.basecls = basecls
        self.paired_sockets = list()  # type: List[TestSocket]
        self.ins = external_obj_pipe or ObjectPipe(name="TestSocket")  # type: ignore  # noqa: E501
        self._has_external_obj_pip = external_obj_pipe is not None
        # There is no separate output channel: send() writes directly into
        # the peers' ``ins`` pipes.
        self.outs = None
        open_test_sockets.append(self)

    def __enter__(self):
        # type: () -> TestSocket
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None  # noqa: E501
        """Close the socket"""
        self.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and Receive multiple packets
        """
        from scapy import sendrecv
        return sendrecv.sndrcv(self, *args, threaded=False, **kargs)

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Send one packet and receive one answer

        :returns: the answer of the first answered request, or None if
            nothing was answered
        """
        from scapy import sendrecv
        ans = sendrecv.sndrcv(self, *args, threaded=False, **kargs)[0]  # type: SndRcvList  # noqa: E501
        if len(ans) > 0:
            pkt = ans[0][1]  # type: Packet
            return pkt
        else:
            return None

    def close(self):
        # type: () -> None
        """Close the socket, unpair it and stop tracking it.

        Calling this more than once is harmless. The socket is always
        removed from ``open_test_sockets``, even if it was already marked
        closed or if closing the underlying pipe raises.
        """
        try:
            if self.closed:
                return
            for s in self.paired_sockets:
                try:
                    s.paired_sockets.remove(self)
                except (ValueError, AttributeError, TypeError):
                    pass
            if not self._has_external_obj_pip:
                super(TestSocket, self).close()
            else:
                # We don't close external object pipes; they belong to the
                # caller. Only mark this socket as closed.
                self.closed = True
        finally:
            # Unregister unconditionally so a failed or repeated close()
            # cannot leave a stale entry behind for cleanup_testsockets().
            try:
                open_test_sockets.remove(self)
            except (ValueError, AttributeError, TypeError):
                pass

    def pair(self, sock):
        # type: (TestSocket) -> None
        """Connect this socket and ``sock`` in both directions."""
        self.paired_sockets += [sock]
        sock.paired_sockets += [self]

    def send(self, x):
        # type: (Packet) -> int
        """Deliver ``bytes(x)`` to every paired socket.

        :returns: number of bytes of the serialized packet
        """
        sx = bytes(x)
        for r in self.paired_sockets:
            r.ins.send(sx)
        try:
            x.sent_time = time.time()
        except AttributeError:
            # x may not accept a sent_time attribute (e.g. raw bytes).
            pass
        return len(sx)

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]  # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)"""
        return self.basecls, self.ins.recv(0), time.time()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """Return the subset of ``sockets`` that have data to read."""
        return select_objects(sockets, remain)


class UnstableSocket(TestSocket):
    """
    This is an unstable socket which randomly raises exceptions on send and
    recv, and occasionally returns None from recv instead of a packet.

    Send and receive each have their own counter. Each direction first lets
    10 packets through without faults. After that, faults can be injected
    at random. Every injected fault resets that direction's counter to 10.
    """

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        super(UnstableSocket, self).__init__(basecls, external_obj_pipe)
        # Number of packets that may still pass in each direction before
        # random faults are enabled again.
        self.no_error_for_x_rx_pkts = 10
        self.no_error_for_x_tx_pkts = 10

    def send(self, x):
        # type: (Packet) -> int
        """Send ``x``, occasionally raising OSError once the tx grace
        period is exhausted."""
        if self.no_error_for_x_tx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_tx_pkts = 10
                print("SOCKET CLOSED")
                raise OSError("Socket closed")
        if self.no_error_for_x_tx_pkts > 0:
            self.no_error_for_x_tx_pkts -= 1
        return super(UnstableSocket, self).send(x)

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """Receive a packet, occasionally raising OSError, Scapy_Exception
        or ValueError, or returning None, once the rx grace period is
        exhausted.

        The None case does not consume anything from the receive queue.
        """
        # Use the receive-side counter; previously this path read and
        # modified the transmit counter, leaving the rx counter unused.
        if self.no_error_for_x_rx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_rx_pkts = 10
                raise OSError("Socket closed")
            if random.randint(0, 1000) == 13:
                self.no_error_for_x_rx_pkts = 10
                raise Scapy_Exception("Socket closed")
            if random.randint(0, 1000) == 7:
                self.no_error_for_x_rx_pkts = 10
                raise ValueError("Socket closed")
            if random.randint(0, 1000) == 113:
                self.no_error_for_x_rx_pkts = 10
                return None
        if self.no_error_for_x_rx_pkts > 0:
            self.no_error_for_x_rx_pkts -= 1
        return super(UnstableSocket, self).recv(x, **kwargs)


def cleanup_testsockets():
    # type: () -> None
    """
    Helper function to remove TestSocket objects after a test

    Every socket still registered in ``open_test_sockets`` gets exactly one
    ``close()`` call. The loop runs over a snapshot, so a socket that fails
    to unregister itself cannot block the others from being closed.
    """
    for sock in list(open_test_sockets):
        sock.close()