scapy/test/testsocket.py

191 lines
5.7 KiB
Python

# SPDX-License-Identifier: GPL-2.0-only
# This file is part of Scapy
# See https://scapy.net/ for more information
# Copyright (C) Nils Weiss <nils@we155.de>
# scapy.contrib.description = TestSocket library for unit tests
# scapy.contrib.status = library
import time
import random
from threading import Lock
from scapy.config import conf
from scapy.automaton import ObjectPipe, select_objects
from scapy.data import MTU
from scapy.packet import Packet
from scapy.error import Scapy_Exception
# Typing imports
from typing import (
Optional,
Type,
Tuple,
Any,
List,
)
from scapy.supersocket import SuperSocket
from scapy.plist import (
PacketList,
SndRcvList,
)
# Registry of TestSocket instances that have not been closed yet:
# TestSocket.__init__ appends every new socket, TestSocket.close() removes it,
# and cleanup_testsockets() closes whatever is still registered.
open_test_sockets = []  # type: List[TestSocket]
class TestSocket(SuperSocket):
    """In-memory SuperSocket used by unit tests.

    ``send()`` serializes a packet with ``bytes()`` and pushes the result
    into the ``ins`` ObjectPipe of every socket this one was ``pair()``-ed
    with; ``recv_raw()`` reads from this socket's own ``ins`` pipe.
    Every instance is registered in the module-level ``open_test_sockets``
    list until it is closed.
    """

    # Class-wide lock; it is not used inside this class itself.
    # NOTE(review): presumably used by callers to serialize tests - confirm.
    test_socket_mutex = Lock()

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        """
        :param basecls: class returned as the first element of
                        ``recv_raw()``.
        :param external_obj_pipe: optional ObjectPipe to use as ``ins``
                                  instead of creating a new one. An external
                                  pipe is not closed by ``close()``.
        """
        global open_test_sockets
        self.basecls = basecls
        self.paired_sockets = list()  # type: List[TestSocket]
        # Decide "external or not" once, with an explicit None test, and use
        # that single decision both for picking the pipe and for close():
        # a truthiness test could replace a falsy external pipe with a new
        # one while still flagging it as external (and thus never closed).
        self._has_external_obj_pip = external_obj_pipe is not None
        if self._has_external_obj_pip:
            self.ins = external_obj_pipe  # type: ignore
        else:
            self.ins = ObjectPipe(name="TestSocket")  # type: ignore
        # No separate output channel: send() writes into the peers' pipes.
        self.outs = None
        open_test_sockets.append(self)

    def __enter__(self):
        # type: () -> TestSocket
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None # noqa: E501
        """Close the socket"""
        self.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and Receive multiple packets

        Delegates to ``sendrecv.sndrcv`` with ``threaded=False``.
        """
        from scapy import sendrecv
        return sendrecv.sndrcv(self, *args, threaded=False, **kargs)

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Send one packet and receive one answer

        Returns the answer of the first matched pair, or None when nothing
        was answered.
        """
        from scapy import sendrecv
        ans = sendrecv.sndrcv(self, *args, threaded=False, **kargs)[0]  # type: SndRcvList  # noqa: E501
        if len(ans) > 0:
            pkt = ans[0][1]  # type: Packet
            return pkt
        else:
            return None

    def close(self):
        # type: () -> None
        """Close the socket and unregister it.

        Detaches this socket from all its peers (in both directions), closes
        the ``ins`` pipe unless it was supplied externally, and removes the
        socket from ``open_test_sockets``. Closing an already closed socket
        does nothing.
        """
        global open_test_sockets
        if self.closed:
            return
        for s in self.paired_sockets:
            try:
                s.paired_sockets.remove(self)
            except (ValueError, AttributeError, TypeError):
                pass
        # Forget our own peers as well, so a closed socket no longer
        # delivers packets from send() and does not keep its peers alive.
        self.paired_sockets = []
        if not self._has_external_obj_pip:
            super(TestSocket, self).close()
        else:
            # We don't close external object pipes
            self.closed = True
        try:
            open_test_sockets.remove(self)
        except (ValueError, AttributeError, TypeError):
            pass

    def pair(self, sock):
        # type: (TestSocket) -> None
        """Connect this socket and ``sock`` in both directions."""
        self.paired_sockets += [sock]
        sock.paired_sockets += [self]

    def send(self, x):
        # type: (Packet) -> int
        """Serialize ``x`` and push the bytes to every paired socket.

        Sets ``x.sent_time`` when the object supports it.

        :returns: the number of serialized bytes.
        """
        sx = bytes(x)
        for r in self.paired_sockets:
            r.ins.send(sx)
        try:
            x.sent_time = time.time()
        except AttributeError:
            # x may be a plain bytes-like object without that attribute.
            pass
        return len(sx)

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]] # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)

        ``cls`` is ``self.basecls``, ``pkt_data`` is the next object read
        from ``self.ins`` and ``time`` is the current ``time.time()``.
        The size hint ``x`` is ignored.
        """
        return self.basecls, self.ins.recv(0), time.time()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """Return the sockets of ``sockets`` that are ready, via
        ``select_objects`` with timeout ``remain``."""
        return select_objects(sockets, remain)
class UnstableSocket(TestSocket):
    """
    This is an unstable socket which randomly fires exceptions on send and
    recv, or makes recv return None instead of a packet.

    Each direction has its own fault-free budget (initially 10 operations).
    Every successful operation decrements it; only once it has reached 0 can
    a fault occur, and each fault resets that direction's budget to 10.
    """

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        super(UnstableSocket, self).__init__(basecls, external_obj_pipe)
        # Remaining guaranteed fault-free operations per direction.
        self.no_error_for_x_rx_pkts = 10
        self.no_error_for_x_tx_pkts = 10

    def send(self, x):
        # type: (Packet) -> int
        """Send ``x``, occasionally raising OSError instead.

        :raises OSError: at random, once the send-side budget is exhausted.
        """
        if self.no_error_for_x_tx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_tx_pkts = 10
                print("SOCKET CLOSED")
                raise OSError("Socket closed")
        if self.no_error_for_x_tx_pkts > 0:
            self.no_error_for_x_tx_pkts -= 1
        return super(UnstableSocket, self).send(x)

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """Receive a packet, occasionally failing instead.

        Once the receive-side budget is exhausted, a call may raise
        OSError, Scapy_Exception or ValueError, or return None without
        reading from the pipe.
        """
        # Receive faults are governed by the receive-side counter only, so
        # they are independent of how many packets were sent.
        if self.no_error_for_x_rx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_rx_pkts = 10
                raise OSError("Socket closed")
            if random.randint(0, 1000) == 13:
                self.no_error_for_x_rx_pkts = 10
                raise Scapy_Exception("Socket closed")
            if random.randint(0, 1000) == 7:
                self.no_error_for_x_rx_pkts = 10
                raise ValueError("Socket closed")
            if random.randint(0, 1000) == 113:
                self.no_error_for_x_rx_pkts = 10
                return None
        if self.no_error_for_x_rx_pkts > 0:
            self.no_error_for_x_rx_pkts -= 1
        return super(UnstableSocket, self).recv(x, **kwargs)
def cleanup_testsockets():
    # type: () -> None
    """
    Helper function to remove TestSocket objects after a test

    Calls close() once on every socket currently registered in
    ``open_test_sockets``. The loop runs over a snapshot because
    TestSocket.close() removes the socket from that list. A socket whose
    close() returns without unregistering it (close() returns early when
    ``closed`` is already set) therefore cannot keep the remaining
    sockets from being closed.
    """
    for sock in list(open_test_sockets):
        sock.close()