mirror of https://github.com/secdev/scapy.git
191 lines
5.7 KiB
Python
191 lines
5.7 KiB
Python
# SPDX-License-Identifier: GPL-2.0-only
|
|
# This file is part of Scapy
|
|
# See https://scapy.net/ for more information
|
|
# Copyright (C) Nils Weiss <nils@we155.de>
|
|
|
|
# scapy.contrib.description = TestSocket library for unit tests
|
|
# scapy.contrib.status = library
|
|
|
|
import time
|
|
import random
|
|
|
|
from threading import Lock
|
|
|
|
from scapy.config import conf
|
|
from scapy.automaton import ObjectPipe, select_objects
|
|
from scapy.data import MTU
|
|
from scapy.packet import Packet
|
|
from scapy.error import Scapy_Exception
|
|
|
|
# Typing imports
|
|
from typing import (
|
|
Optional,
|
|
Type,
|
|
Tuple,
|
|
Any,
|
|
List,
|
|
)
|
|
from scapy.supersocket import SuperSocket
|
|
|
|
from scapy.plist import (
|
|
PacketList,
|
|
SndRcvList,
|
|
)
|
|
|
|
|
|
# Module-wide registry of TestSocket instances: a socket appends itself on
# construction and removes itself again in close().
open_test_sockets = []  # type: List[TestSocket]
|
|
|
|
|
|
class TestSocket(SuperSocket):
    """
    In-memory socket for unit tests.

    Whatever is sent on a TestSocket is serialized to bytes and pushed into
    the ``ins`` object pipe of every socket it has been paired with; data is
    received by reading from the socket's own ``ins`` pipe.
    """

    test_socket_mutex = Lock()

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        """
        :param basecls: class handed back by recv_raw together with the data
        :param external_obj_pipe: pipe to read from. If not given, the
            socket creates its own ObjectPipe.
        """
        self.basecls = basecls
        self.paired_sockets = []  # type: List[TestSocket]
        self.ins = (
            external_obj_pipe or ObjectPipe(name="TestSocket")
        )  # type: ignore
        # close() consults this flag: a pipe supplied by the caller is
        # never closed by this socket.
        self._has_external_obj_pip = external_obj_pipe is not None
        self.outs = None
        # Register the socket in the module-level list of open sockets
        open_test_sockets.append(self)

    def __enter__(self):
        # type: () -> TestSocket
        """Return the socket itself for use in a ``with`` statement"""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None  # noqa: E501
        """Close the socket"""
        self.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and Receive multiple packets

        All arguments are forwarded to sendrecv.sndrcv, which is always
        called with threaded=False.
        """
        from scapy import sendrecv
        return sendrecv.sndrcv(self, *args, threaded=False, **kargs)

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Send one packet and receive one answer

        :return: the reply of the first answered packet, or None if
            nothing was answered
        """
        from scapy import sendrecv
        result = sendrecv.sndrcv(self, *args, threaded=False, **kargs)
        answered = result[0]  # type: SndRcvList
        if len(answered) == 0:
            return None
        reply = answered[0][1]  # type: Packet
        return reply

    def close(self):
        # type: () -> None
        """
        Close the socket, detach it from its paired sockets and remove it
        from the list of open test sockets. Does nothing if the socket is
        already closed.
        """
        if self.closed:
            return

        # Make sure no peer keeps delivering packets to this socket
        for peer in self.paired_sockets:
            try:
                peer.paired_sockets.remove(self)
            except (ValueError, AttributeError, TypeError):
                pass

        if self._has_external_obj_pip:
            # We don't close external object pipes
            self.closed = True
        else:
            super(TestSocket, self).close()

        try:
            open_test_sockets.remove(self)
        except (ValueError, AttributeError, TypeError):
            pass

    def pair(self, sock):
        # type: (TestSocket) -> None
        """Connect this socket and ``sock`` in both directions"""
        self.paired_sockets.append(sock)
        sock.paired_sockets.append(self)

    def send(self, x):
        # type: (Packet) -> int
        """
        Deliver the bytes of ``x`` to every paired socket.

        :return: number of bytes of the serialized packet
        """
        raw = bytes(x)
        for peer in self.paired_sockets:
            peer.ins.send(raw)
        try:
            x.sent_time = time.time()
        except AttributeError:
            # x does not accept a sent_time attribute
            pass
        return len(raw)

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]  # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)"""
        # The size argument x is not used; the pipe is read with recv(0).
        cls = self.basecls
        data = self.ins.recv(0)
        return cls, data, time.time()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """Delegate the readiness check to select_objects"""
        return select_objects(sockets, remain)
|
|
|
|
|
|
class UnstableSocket(TestSocket):
    """
    This is an unstable socket which randomly fires exceptions or loses
    packets on recv.

    Two independent grace counters delay the fault injection:

    - ``no_error_for_x_tx_pkts``: number of upcoming send() calls that are
      guaranteed to succeed
    - ``no_error_for_x_rx_pkts``: number of upcoming recv() calls that are
      guaranteed to be free of injected faults

    Both start at 10 and are reset to 10 whenever a fault was injected in
    the respective direction.
    """

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        """
        :param basecls: forwarded to TestSocket
        :param external_obj_pipe: forwarded to TestSocket
        """
        super(UnstableSocket, self).__init__(basecls, external_obj_pipe)
        self.no_error_for_x_rx_pkts = 10
        self.no_error_for_x_tx_pkts = 10

    def send(self, x):
        # type: (Packet) -> int
        """
        Send ``x`` like TestSocket.send, unless a fault is injected.

        :raises OSError: randomly, once the tx grace counter reached zero
        """
        if self.no_error_for_x_tx_pkts == 0:
            # randint is inclusive on both ends: 1 out of 1001 calls fails
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_tx_pkts = 10
                print("SOCKET CLOSED")
                raise OSError("Socket closed")
        if self.no_error_for_x_tx_pkts > 0:
            self.no_error_for_x_tx_pkts -= 1
        return super(UnstableSocket, self).send(x)

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """
        Receive like the parent class, unless a fault is injected.

        Once the rx grace counter reached zero, each call may randomly
        raise or drop the read (return None without touching the pipe).

        :raises OSError: randomly injected fault
        :raises Scapy_Exception: randomly injected fault
        :raises ValueError: randomly injected fault
        """
        # The receive path is governed by the rx counter only. Sharing the
        # tx counter would leave no_error_for_x_rx_pkts unused and let
        # every recv() shorten the grace period of send().
        if self.no_error_for_x_rx_pkts == 0:
            # Every check draws its own random number
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_rx_pkts = 10
                raise OSError("Socket closed")
            if random.randint(0, 1000) == 13:
                self.no_error_for_x_rx_pkts = 10
                raise Scapy_Exception("Socket closed")
            if random.randint(0, 1000) == 7:
                self.no_error_for_x_rx_pkts = 10
                raise ValueError("Socket closed")
            if random.randint(0, 1000) == 113:
                # Simulated packet loss
                self.no_error_for_x_rx_pkts = 10
                return None
        if self.no_error_for_x_rx_pkts > 0:
            self.no_error_for_x_rx_pkts -= 1
        return super(UnstableSocket, self).recv(x, **kwargs)
|
|
|
|
|
|
def cleanup_testsockets():
    # type: () -> None
    """
    Helper function to remove TestSocket objects after a test

    Every socket that is registered in ``open_test_sockets`` when the
    function is called gets closed exactly once and is unregistered.
    """
    # Work on a snapshot: close() removes the socket from
    # open_test_sockets, so the registry itself must not be iterated.
    for sock in list(open_test_sockets):
        sock.close()
        # close() returns early, without unregistering, when the socket is
        # already flagged as closed. Drop such a leftover here; otherwise
        # it stays in the registry and, sitting at the front, would keep
        # the remaining sockets from ever being closed.
        try:
            open_test_sockets.remove(sock)
        except ValueError:
            # Regular case: close() already unregistered the socket
            pass
|