43 Class that describes the route for TcpGate or UdpGate.
45 Use `port_for_client == 0` to bind to some unused port. In that case the
46 actual address could be retrieved via BaseGate.get_sockname_for_clients().
48 @ingroup userver_testsuite
54 host_for_client: str =
'127.0.0.1'
55 port_for_client: int = 0
65logger = logging.getLogger(__name__)
68Address: TypeAlias = tuple[str, int]
69EvLoop: TypeAlias = Any
70Socket: TypeAlias = socket.socket
71Interceptor: TypeAlias = Callable[[EvLoop, Socket, Socket], Coroutine[Any, Any,
None]]
class GateException(Exception):
    """Error type of the chaos gates.

    NOTE(review): presumably the generic error for gate failures; confirm
    against its raise sites.
    """
class GateInterceptException(Exception):
    """Raised by an interceptor to deliberately stop forwarding on a connection.

    Interceptors raise it on purpose, e.g. when a time or byte limit is hit
    or when data arrives on a close-on-data route. The pair supervisor
    (``_run``) logs it at INFO level instead of treating it as an error.

    NOTE(review): it derives from ``Exception`` directly rather than from
    ``GateException``, so ``except GateException`` does not catch it.
    """
# Pass-through interceptor: receives up to RECV_MAX_SIZE bytes from
# `socket_from` and forwards them unchanged to `socket_to`.
# Raises ConnectionClosedError when the peer has closed the connection
# (presumably signalled by an empty recv result - confirm).
82async def _intercept_ok(
87 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
89 raise ConnectionClosedError()
90 await loop.sock_sendall(socket_to, data)
# Drop interceptor: receives up to RECV_MAX_SIZE bytes from `socket_from`
# and discards them; the received data is not forwarded to `socket_to`.
# Raises ConnectionClosedError when the peer has closed the connection
# (presumably signalled by an empty recv result - confirm).
93async def _intercept_drop(
98 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
100 raise ConnectionClosedError()
# Delay interceptor: receives up to RECV_MAX_SIZE bytes from `socket_from`,
# waits `delay` seconds, then forwards the chunk unchanged to `socket_to`.
# Raises ConnectionClosedError when the peer has closed the connection.
# NOTE(review): interceptors are awaited as (loop, socket_from, socket_to),
# so `delay` must be bound beforehand - presumably via functools.partial;
# confirm the parameter order matches that binding.
103async def _intercept_delay(
109 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
111 raise ConnectionClosedError()
112 await asyncio.sleep(delay)
113 await loop.sock_sendall(socket_to, data)
# Close-on-data interceptor: reads a single byte from `socket_from` and then
# raises GateInterceptException so forwarding on this connection stops; the
# byte that was read is not forwarded.
# Raises ConnectionClosedError instead when the peer closed first
# (presumably signalled by an empty recv result - confirm).
116async def _intercept_close_on_data(
121 data = await loop.sock_recv(socket_from, 1)
123 raise ConnectionClosedError()
124 raise GateInterceptException(
'Closing socket on data')
# Corrupting interceptor: receives up to RECV_MAX_SIZE bytes from
# `socket_from` and forwards a damaged copy of the same length.
# Raises ConnectionClosedError when the peer has closed the connection.
# NOTE(review): `not x` is a boolean negation, so every non-zero byte becomes
# 0 and every zero byte becomes 1 (not a bitwise inversion; `x ^ 0xFF` would
# flip all bits). Either way every forwarded byte differs from the input.
127async def _intercept_corrupt(
132 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
134 raise ConnectionClosedError()
135 await loop.sock_sendall(socket_to, bytearray([
not x
for x
in data]))
# Interceptor that throttles forwarding to roughly `bytes_per_second`.
# It keeps a byte budget (`_bytes_left`) that is refilled in proportion to
# elapsed monotonic time and capped at one second's worth of bytes (a token
# bucket); every forwarded chunk is charged against that budget.
138class _InterceptBpsLimit:
# `bytes_per_second` must be >= 1 (checked with assert). The budget starts
# full, and `_time_last_added` starts at 0.0, so the first refill simply
# hits the cap.
139 def __init__(self, bytes_per_second: float):
140 assert bytes_per_second >= 1
141 self._bytes_per_second = bytes_per_second
142 self._time_last_added = 0.0
143 self._bytes_left = self._bytes_per_second
# Refill the budget by bytes_per_second * (seconds since the last refill)
# and clamp it to at most bytes_per_second.
145 def _update_limit(self) -> None:
146 current_time = time.monotonic()
147 elapsed = current_time - self._time_last_added
148 bytes_addition = self._bytes_per_second * elapsed
149 if bytes_addition > 0:
150 self._bytes_left += bytes_addition
151 self._time_last_added = current_time
153 if self._bytes_left > self._bytes_per_second:
154 self._bytes_left = self._bytes_per_second
# Per call: receive at most int(budget) bytes (and at most RECV_MAX_SIZE),
# charge them against the budget and forward them. When less than one byte
# of budget is available, the call logs and sleeps 1 / bytes_per_second
# seconds instead (presumably the else-branch of the check below).
# Raises ConnectionClosedError when the peer has closed the connection.
164 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
165 if bytes_to_recv > 0:
166 data = await loop.sock_recv(socket_from, bytes_to_recv)
168 raise ConnectionClosedError()
169 self._bytes_left -= len(data)
171 await loop.sock_sendall(socket_to, data)
173 logger.info(
'Socket hits the bytes per second limit')
174 await asyncio.sleep(1.0 / self._bytes_per_second)
# Interceptor that forwards data like _intercept_ok until a per-socket
# deadline expires, then raises GateInterceptException.
# The deadline is `timeout` plus `jitter * random.random()` seconds, fixed
# the first time a given `socket_from` is seen.
177class _InterceptTimeLimit:
# `_sockets` maps each source socket to its time.monotonic() deadline.
# `timeout` must be >= 0.0 (checked with assert).
178 def __init__(self, timeout: float, jitter: float):
179 self._sockets: dict[Socket, float] = {}
180 assert timeout >= 0.0
181 self._timeout = timeout
183 self._jitter = jitter
# Give `socket_from` a deadline on first use. Once that deadline has passed,
# forget the socket and raise GateInterceptException.
# NOTE(review): entries are removed only when their deadline fires, so a
# socket that closes earlier stays in `_sockets` (keeping the socket object
# referenced) for as long as this interceptor lives.
185 def raise_if_timed_out(self, socket_from: Socket) ->
None:
186 if socket_from
not in self._sockets:
187 jitter = self._jitter * random.random()
188 expire_at = time.monotonic() + self._timeout + jitter
189 self._sockets[socket_from] = expire_at
191 if self._sockets[socket_from] <= time.monotonic():
192 del self._sockets[socket_from]
193 raise GateInterceptException(
'Socket hits the time limit')
# The limit is only checked when the interceptor is invoked, so it takes
# effect on the first chunk handled after the deadline.
201 self.raise_if_timed_out(socket_from)
202 await _intercept_ok(loop, socket_from, socket_to)
# Interceptor that forwards data in small, delayed pieces: each call
# receives at most `max_size` bytes and sleeps `sleep_per_packet` seconds
# before forwarding them unchanged.
205class _InterceptSmallerParts:
206 def __init__(self, max_size: int, sleep_per_packet: float):
208 self._max_size = max_size
209 self._sleep_per_packet = sleep_per_packet
# Raises ConnectionClosedError when the peer has closed the connection.
217 data = await loop.sock_recv(socket_from, self._max_size)
219 raise ConnectionClosedError()
220 await asyncio.sleep(self._sleep_per_packet)
221 await loop.sock_sendall(socket_to, data)
# Interceptor that buffers incoming data and forwards it as one sendall
# only after at least `packet_size` bytes have accumulated.
# `packet_size` must be >= 0 (checked with assert); with 0, every chunk is
# forwarded immediately because tell() >= 0 always holds.
224class _InterceptConcatPackets:
# `_expire_at` is the monotonic deadline for completing the pending packet;
# it is None while no packet is being assembled.
225 def __init__(self, packet_size: int):
226 assert packet_size >= 0
227 self._packet_size = packet_size
228 self._expire_at: float |
None =
None
229 self._buf = io.BytesIO()
# A deadline of MAX_DELAY seconds starts with the first call of a new
# packet. If it passes before the buffer is full, an error with the message
# below is reported (presumably a raised GateInterceptException - confirm).
237 if self._expire_at
is None:
238 self._expire_at = time.monotonic() + MAX_DELAY
240 if self._expire_at <= time.monotonic():
242 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
243 'seconds. Check the test logic, it should end with checking '
244 'that the data was sent and by calling TcpGate function '
245 'to_client_pass() to pass the remaining packets.',
# Accumulate the chunk. Once the buffer reaches packet_size, flush all of it
# and reset both the buffer and the deadline.
# Raises ConnectionClosedError when the peer has closed the connection.
248 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
250 raise ConnectionClosedError()
251 self._buf.write(data)
252 if self._buf.tell() >= self._packet_size:
253 await loop.sock_sendall(socket_to, self._buf.getvalue())
254 self._buf = io.BytesIO()
255 self._expire_at =
None
# Interceptor that lets at most `bytes_limit` bytes through. The chunk that
# reaches the limit is truncated to the remaining budget and forwarded.
# Then the gate's sockets_close() is awaited (presumably closing all proxied
# connections), the budget is reset to `bytes_limit`, and
# GateInterceptException is raised.
# A chunk that exactly exhausts the budget is forwarded in full before the
# sockets are closed, because the check is `<=`.
258class _InterceptBytesLimit:
# `bytes_limit` must be >= 0 (checked with assert). The `gate` argument is
# presumably stored as `self._gate`, which is used below - confirm.
259 def __init__(self, bytes_limit: int, gate: BaseGate):
260 assert bytes_limit >= 0
261 self._bytes_limit = bytes_limit
262 self._bytes_remain = self._bytes_limit
# Raises ConnectionClosedError when the peer has closed the connection.
271 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
273 raise ConnectionClosedError()
274 if self._bytes_remain <= len(data):
275 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
276 await self._gate.sockets_close()
277 self._bytes_remain = self._bytes_limit
278 raise GateInterceptException(
'Data transmission limit reached')
279 self._bytes_remain -= len(data)
280 await loop.sock_sendall(socket_to, data)
# Interceptor that rewrites traffic. Each received chunk is decoded with
# `encoding`, `pattern` is replaced by `repl` via re.sub, and the re-encoded
# text is forwarded.
# NOTE(review): substitution works per recv() chunk, so a match spanning two
# chunks is not replaced, and a multi-byte character split across chunks
# cannot be decoded on its own - confirm how decode errors are handled here.
283class _InterceptSubstitute:
# `pattern` is compiled once. `repl` is presumably stored as `self._repl`,
# which is used below - confirm.
284 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
285 self._pattern = re.compile(pattern)
287 self._encoding = encoding
# Raises ConnectionClosedError when the peer has closed the connection.
295 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
297 raise ConnectionClosedError()
299 res = self._pattern.sub(self._repl, data.decode(self._encoding))
300 data = res.encode(self._encoding)
303 await loop.sock_sendall(socket_to, data)
# Cancel `task` and wait for it to finish.
# A None or already-cancelled task is skipped (presumably via an early
# return). The CancelledError produced by awaiting the cancelled task is
# expected and handled. Other exceptions from the task are logged with a
# traceback via logger.exception instead of propagating (presumably through a
# broad `except Exception` - confirm).
# NOTE(review): unless the CancelledError handler re-raises, a cancellation
# of the *calling* task that arrives during the await is swallowed as well.
306async def _cancel_and_join(task: asyncio.Task |
None) ->
None:
307 if not task
or task.cancelled():
313 except asyncio.CancelledError:
316 logger.exception(
'Exception in _cancel_and_join')
319def _make_socket_nonblocking(sock: Socket) ->
None:
320 sock.setblocking(
False)
321 if sock.type == socket.SOCK_STREAM:
322 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Per-client view of a shared UDP socket. Datagrams from one client address
# are pushed into a private SOCK_DGRAM socketpair (`push` writes to
# `_demux_in`; `recv`/`recvfrom`/`fileno` use `_demux_out`), so each client
# can be handled like its own connected socket. Replies go out through the
# shared socket to `peer_address` (`send`).
325class _UdpDemuxSocketMock:
327 Emulates a point-to-point connection over UDP socket
328 with a non-blocking socket interface
# Reports the timeout of the shared UDP socket, not of the demux pair.
331 def gettimeout(self):
332 return self._sock.gettimeout()
# `sock` is the gate's shared UDP socket; `peer_address` is the client
# address whose datagrams this object represents.
334 def __init__(self, sock: Socket, peer_address: Address):
335 self._sock: Socket = sock
336 self._peeraddr: Address = peer_address
338 sockpair = socket.socketpair(type=socket.SOCK_DGRAM)
339 self._demux_in: Socket = sockpair[0]
340 self._demux_out: Socket = sockpair[1]
341 _make_socket_nonblocking(self._demux_in)
342 _make_socket_nonblocking(self._demux_out)
343 self._is_active: bool =
True
# Callers compare `client.peer_address` without calling it, so this is
# presumably exposed as a @property.
346 def peer_address(self):
347 return self._peeraddr
# Queue a datagram received from this client for the reading side.
349 async def push(self, loop: EvLoop, data: bytes):
350 return await loop.sock_sendall(self._demux_in, data)
# Whether this mock has not been shut down yet (callers use `is_active()`).
353 return self._is_active
# Shutdown: mark the mock inactive and close both ends of the socketpair.
# The shared UDP socket `_sock` stays open; it is the gate's accepting socket.
356 self._is_active =
False
357 self._demux_out.close()
358 self._demux_in.close()
# NOTE(review): these read what arrived on the socketpair, so the address
# reported by recvfrom is the socketpair peer, not the client's
# `peer_address`.
360 def recvfrom(self, bufsize: int, flags: int = 0):
361 return self._demux_out.recvfrom(bufsize, flags)
363 def recv(self, bufsize: int, flags: int = 0):
364 return self._demux_out.recv(bufsize, flags)
366 def get_demux_out(self):
367 return self._demux_out
# fileno(): the demux-out descriptor, presumably so readiness polling sees
# this client's queued datagrams.
370 return self._demux_out.fileno()
# Replies are sent through the shared socket, addressed to the client.
372 def send(self, data: bytes):
373 return self._sock.sendto(data, self._peeraddr)
# One forwarding direction of a proxied connection: data read from
# `socket_from` is passed to `socket_to` through the current `interceptor`.
# `_condition` guards replacement of the interceptor (see set_interceptor).
377 def __init__(self, socket_from, socket_to, interceptor):
378 self._socket_from = socket_from
379 self._socket_to = socket_to
380 self._condition = asyncio.Condition()
381 self._interceptor = interceptor
# Current interceptor. Also used as the predicate of `_condition.wait_for`
# in the forwarding loop, so a falsy interceptor (e.g. None) pauses
# forwarding until a new one is set.
383 def get_interceptor(self):
384 return self._interceptor
# Replace the interceptor while holding the condition's lock, and wake one
# waiter, i.e. a forwarding loop blocked in `wait_for(self.get_interceptor)`.
386 async def set_interceptor(self, interceptor):
387 async with self._condition:
388 self._interceptor = interceptor
389 self._condition.notify()
# Forwarding step of InterceptTask.run() (started per direction via
# `obj.run()`). It presumably waits until `socket_from` is readable, then
# takes the current interceptor (blocking while it is falsy) and runs it
# once. Waiting for data *before* taking the interceptor means a newly set
# interceptor handles the next chunk, instead of an outdated one sitting in
# a blocking recv.
# NOTE(review): `logging.trace` is not part of the standard library's
# logging module and bypasses the module-level `logger`; presumably it is
# added by the project - confirm.
392 loop = asyncio.get_running_loop()
399 await _wait_for_data(self._socket_from)
402 async with self._condition:
403 interceptor = await self._condition.wait_for(self.get_interceptor)
405 logging.trace(
'running interceptor: %s', interceptor)
406 await interceptor(loop, self._socket_from, self._socket_to)
414 client: socket.socket | _UdpDemuxSocketMock,
415 server: socket.socket,
416 to_server_intercept: Interceptor,
417 to_client_intercept: Interceptor,
# The to-server direction reads from `client` and writes to `server`; the
# to-client direction does the opposite. Each direction gets its own
# InterceptTask, so their interceptors can be swapped independently.
419 self._proxy_name = proxy_name
421 self._client = client
422 self._server = server
424 self._task_to_server = InterceptTask(client, server, to_server_intercept)
425 self._task_to_client = InterceptTask(server, client, to_client_intercept)
# create_task only schedules _run(); it cannot start before this synchronous
# method returns. So `_interceptor_tasks = []` below is assigned first, and
# _run() later replaces it with the real per-direction tasks.
427 self._task = asyncio.create_task(self._run())
428 self._interceptor_tasks = []
# Swap the interceptor of the client -> server direction.
430 async def set_to_server_interceptor(self, interceptor):
431 await self._task_to_server.set_interceptor(interceptor)
# Swap the interceptor of the server -> client direction.
433 async def set_to_client_interceptor(self, interceptor: Interceptor):
434 await self._task_to_client.set_interceptor(interceptor)
# Cancel and await both per-direction tasks, then the supervising _run task.
436 async def shutdown(self) -> None:
437 for task
in self._interceptor_tasks:
438 await _cancel_and_join(task)
439 await _cancel_and_join(self._task)
# The pair is active until its supervising _run task has finished.
441 def is_active(self) -> bool:
442 return not self._task.done()
# Human-readable description of the pair showing the client and server file
# descriptors; an inactive pair takes the early branch instead.
444 def info(self) -> str:
445 if not self.is_active():
448 return f
'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'
# Supervisor of the pair. It starts one task per direction and waits until
# one of them fails (FIRST_EXCEPTION; the forwarding loops presumably never
# finish normally). The failure is then classified:
#   GateInterceptException - intentional stop requested by an interceptor,
#                            logged at INFO;
#   OSError                - logged at ERROR without a traceback;
#   anything else          - logged with a traceback.
# Afterwards the interceptor tasks are cleaned up (presumably cancelled) and
# both sockets are closed; errors while closing are logged, not raised.
450 async def _run(self):
451 self._interceptor_tasks = [
452 asyncio.create_task(obj.run())
for obj
in (self._task_to_server, self._task_to_client)
455 done, _ = await asyncio.wait(self._interceptor_tasks, return_when=asyncio.FIRST_EXCEPTION)
458 except GateInterceptException
as exc:
459 logger.info(
'In "%s": %s', self._proxy_name, exc)
460 except OSError
as exc:
461 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
463 logger.exception(
'interceptor failed')
465 for task
in self._interceptor_tasks:
470 for sock
in self._server, self._client:
474 logger.exception(
'Exception in "%s" on closing %s:', self._proxy_name, sock)
# TCP gate: accepts clients on (host_for_client, port_for_client) and, for
# every accepted client, opens a dedicated connection to the server and
# proxies both directions through the configured interceptors.
839class TcpGate(BaseGate):
841 Implements TCP chaos-proxy logic such as accepting incoming tcp client
842 connections. On each new connection new tcp client connects to server
843 (host_to_server, port_to_server).
845 @ingroup userver_testsuite
847 @see @ref scripts/docs/en/userver/chaos_testing.md
# `_connected_event` is set whenever a new connection is registered and is
# consumed by wait_for_connections().
850 def __init__(self, route: GateRoute, loop: EvLoop | None = None) -> None:
851 self._connected_event = asyncio.Event()
852 super().__init__(route, loop)
# Counts every tracked connection pair, including ones that may be closing.
854 def connections_count(self) -> int:
856 Returns maximal amount of connections going through the gate at
859 @warning Some of the connections could be closing, or could be opened
860 right before the function starts. Use with caution!
862 return len(self._sockets)
# Without a positive timeout (presumably `timeout <= 0`), the first loop
# waits indefinitely. Otherwise waiting is bounded by an absolute deadline,
# and asyncio.wait_for raises asyncio.TimeoutError once the remaining time
# runs out. Clearing the event after each wake-up is safe, because the loop
# re-checks connections_count() before waiting again.
864 async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
866 Wait for at least `count` connections going through the gate.
868 @throws asyncio.TimeoutError exception if failed to get the
869 required amount of connections in time.
872 while self.connections_count() < count:
873 await self._connected_event.wait()
874 self._connected_event.clear()
877 deadline = time.monotonic() + timeout
878 while self.connections_count() < count:
879 time_left = deadline - time.monotonic()
880 await asyncio.wait_for(
881 self._connected_event.wait(),
884 self._connected_event.clear()
# One accepting socket per getaddrinfo() result for
# (host_for_client, port_for_client). Each is non-blocking and has
# SO_REUSEADDR set, so a recently used port can be bound again.
886 def _create_accepting_sockets(self) -> list[Socket]:
887 res: list[Socket] = []
888 for addr in socket.getaddrinfo(
889 self._route.host_for_client,
890 self._route.port_for_client,
891 type=socket.SOCK_STREAM,
893 sock = Socket(addr[0], addr[1])
894 _make_socket_nonblocking(sock)
895 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
899 f'Accepting connections on {sock.getsockname()}, fd={sock.fileno()}',
# Resolve (host_to_server, port_to_server) and try the addresses in order.
# A failed connect is logged as a warning and the next address is tried.
# Presumably returns the first connected socket, or None if all fail -
# confirm.
# NOTE(review): uses the root `logging` module (including the non-stdlib
# `logging.trace`) instead of the module-level `logger`.
905 async def _connect_to_server(self):
906 addrs = await self._loop.getaddrinfo(
907 self._route.host_to_server,
908 self._route.port_to_server,
909 type=socket.SOCK_STREAM,
912 server = Socket(addr[0], addr[1])
913 _make_socket_nonblocking(server)
915 await self._loop.sock_connect(server, addr[4])
916 logging.trace('Connected to %s', addr[4])
918 except Exception as exc: # pylint: disable=broad-except
920 logging.warning('Could not connect to %s: %s', addr[4], exc)
# Accept a client, connect a dedicated server socket, and register the pair
# with the gate's current to-server/to-client interceptors. Then
# `_connected_event` is set to wake wait_for_connections(), and finished
# pairs are pruned with _collect_garbage().
922 async def _do_accept(self, accept_sock: Socket) -> None:
924 client, _ = await self._loop.sock_accept(accept_sock)
925 _make_socket_nonblocking(client)
927 server = await self._connect_to_server()
935 self._to_server_intercept,
936 self._to_client_intercept,
939 self._connected_event.set()
943 self._collect_garbage()
# UDP gate: shared sockets receive datagrams from all clients. Each new
# client address gets a _UdpDemuxSocketMock and its own server connection,
# so clients are proxied independently of each other.
946class UdpGate(BaseGate):
948 Implements UDP chaos-proxy logic such as demuxing incoming datagrams
949 from different clients.
950 Separate connections to server are made for each new client.
952 @ingroup userver_testsuite
954 @see @ref scripts/docs/en/userver/chaos_testing.md
# `_clients` holds the demux mocks of the client addresses seen so far.
957 def __init__(self, route: GateRoute, loop: EvLoop | None = None):
958 self._clients: set[_UdpDemuxSocketMock] = set()
959 super().__init__(route, loop)
# True if at least one socket pair is currently tracked.
961 def is_connected(self) -> bool:
963 Returns True if there is active pair of sockets ready to transfer data
966 return len(self._sockets) > 0
# One non-blocking UDP socket per getaddrinfo() result for
# (host_for_client, port_for_client).
# NOTE(review): uses `socket.socket(...)` where the rest of the module uses
# the equivalent `Socket` alias.
968 def _create_accepting_sockets(self) -> list[Socket]:
969 res: list[Socket] = []
970 for addr in socket.getaddrinfo(
971 self._route.host_for_client,
972 self._route.port_for_client,
973 type=socket.SOCK_DGRAM,
975 sock = socket.socket(addr[0], addr[1])
976 _make_socket_nonblocking(sock)
978 logger.debug(f'Accepting connections on {sock.getsockname()}')
# Same address iteration as the TCP gate, but with SOCK_DGRAM. Connecting a
# UDP socket only fixes its default peer; no handshake takes place.
# Failures are logged as warnings and the next address is tried.
983 async def _connect_to_server(self):
984 addrs = await self._loop.getaddrinfo(
985 self._route.host_to_server,
986 self._route.port_to_server,
987 type=socket.SOCK_DGRAM,
990 server = Socket(addr[0], addr[1])
992 _make_socket_nonblocking(server)
993 await self._loop.sock_connect(server, addr[4])
994 logging.trace('Connected to %s', addr[4])
996 except Exception as exc: # pylint: disable=broad-except
997 logging.warning('Could not connect to %s: %s', addr[4], exc)
# Besides the base-class cleanup, drop demux mocks that are no longer active.
999 def _collect_garbage(self) -> None:
1000 super()._collect_garbage()
1001 self._clients = {c for c in self._clients if c.is_active()}
# Receive one datagram (timeout=60.0 is passed to the project's asyncio_socket
# wrapper; its timeout behaviour is defined elsewhere), then find the demux
# mock for the sender by a linear scan of `_clients`. For an unknown sender,
# a server connection is opened, and a new _UdpDemuxSocketMock over the
# shared socket is registered and paired (presumably only if the connection
# succeeded). Finally the datagram is pushed into that client's demux queue,
# where the to-server direction picks it up.
1003 async def _do_accept(self, accept_sock: Socket):
1004 sock = asyncio_socket.from_socket(accept_sock)
1006 data, addr = await sock.recvfrom(RECV_MAX_SIZE, timeout=60.0)
1008 client: _UdpDemuxSocketMock | None = None
1009 for known_clients in self._clients:
1010 if addr == known_clients.peer_address:
1011 client = known_clients
1015 server = await self._connect_to_server()
1020 client = _UdpDemuxSocketMock(accept_sock, addr)
1021 self._clients.add(client)
1029 self._to_server_intercept,
1030 self._to_client_intercept,
1034 await client.push(self._loop, data)
1035 self._collect_garbage()
# UDP datagrams are forwarded whole, so concatenating or splitting them is
# not supported: these overrides raise NotImplementedError.
1037 async def to_server_concat_packets(self, packet_size: int) -> None:
1038 raise NotImplementedError('Udp packets cannot be concatenated')
1040 async def to_client_concat_packets(self, packet_size: int) -> None:
1041 raise NotImplementedError('Udp packets cannot be concatenated')
1043 async def to_server_smaller_parts(
1047 sleep_per_packet: float = 0,
1049 raise NotImplementedError('Udp packets cannot be split')
1051 async def to_client_smaller_parts(
1055 sleep_per_packet: float = 0,
1057 raise NotImplementedError('Udp packets cannot be split')