26 Class that describes the route for TcpGate or UdpGate.
28 Use `port_for_client == 0` to bind to some unused port. In that case the
29 actual address could be retrieved via BaseGate.get_sockname_for_clients().
31 @ingroup userver_testsuite
37 host_for_client: str =
'127.0.0.1'
38 port_for_client: int = 0
logger = logging.getLogger(__name__)

# Socket address as returned by ``recvfrom()`` for IP sockets: (host, port).
Address = typing.Tuple[str, int]

# An interceptor is a coroutine function called with the event loop, the
# socket to read from and the socket to write to; it forwards (or delays,
# alters, drops) the data in between.
# NOTE(review): EvLoop, Socket, RECV_MAX_SIZE and MAX_DELAY are used in this
# module but their definitions are not present in the code visible here.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
class GateException(Exception):
    # NOTE(review): the class body was lost in this copy and the exception
    # is not raised in the visible code; an empty body keeps it a plain
    # Exception subclass.
    pass
class GateInterceptException(Exception):
    """Raised by an interceptor to abort forwarding on a socket pair.

    Interceptors raise it when a configured limit is hit (time, data volume)
    or deliberately (close-on-data); the pipe loop catches it and logs it at
    INFO level rather than as an error.
    """
68async def _yield() -> None:
73 await asyncio.sleep(min_delay)
def _try_get_message(
    recv_socket: Socket,
    flags: int,
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
    """Non-blocking ``recvfrom()`` of up to RECV_MAX_SIZE bytes.

    @param recv_socket socket to read from (expected to be non-blocking)
    @param flags flags passed to ``recvfrom()``, e.g. ``socket.MSG_PEEK``
    @returns ``(data, address)``, or ``(None, None)`` if nothing can be read
             right now.
    """
    # NOTE(review): the signature, ``try:`` and the ``except`` body were lost
    # in this copy; reconstructed from the call sites and the return type.
    try:
        return recv_socket.recvfrom(RECV_MAX_SIZE, flags)
    except (BlockingIOError, InterruptedError):
        # Nothing queued (or interrupted by a signal): report "no message".
        return None, None
async def _get_message_task(
    recv_socket: Socket,
) -> typing.Tuple[bytes, Address]:
    """Wait until a message can be read from `recv_socket` and return it.

    Polls the socket, yielding to the event loop between attempts.

    @returns ``(data, sender_address)`` as returned by ``recvfrom()``.
    """
    # NOTE(review): the parameter and the polling loop were lost in this
    # copy; reconstructed from the non-optional return type.
    while True:
        msg, addr = _try_get_message(recv_socket, 0)
        # Compare with None so that empty datagrams are still returned.
        if msg is not None and addr is not None:
            return msg, addr
        await _yield()
def _incoming_data_size(recv_socket: Socket) -> int:
    """Peek at `recv_socket` and return how many bytes are waiting.

    Nothing is consumed (``MSG_PEEK``), and at most RECV_MAX_SIZE bytes are
    peeked. Returns 0 when no data is available.
    """
    peeked, _unused_addr = _try_get_message(recv_socket, socket.MSG_PEEK)
    if not peeked:
        return 0
    return len(peeked)
async def _intercept_ok(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Pass-through interceptor: forward one chunk of data unchanged."""
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, data)
async def _intercept_noop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Interceptor that transfers nothing; data stays queued in `socket_from`."""
    # NOTE(review): the body was lost in this copy. Yield rather than return
    # immediately so a caller that keeps retrying cannot starve the loop.
    await _yield()
async def _intercept_drop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Interceptor that reads one chunk from `socket_from` and discards it."""
    await loop.sock_recv(socket_from, RECV_MAX_SIZE)
async def _intercept_delay(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
    delay: float,
) -> None:
    """Interceptor that forwards one chunk after sleeping `delay` seconds.

    `delay` has to be bound beforehand (e.g. with functools.partial) for the
    function to match the three-argument Interceptor signature.
    """
    # NOTE(review): the parameter list was lost in this copy; ``delay`` is
    # placed last so it can be bound by keyword or trailing position.
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, data)
async def _intercept_close_on_data(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Interceptor that aborts as soon as any data arrives.

    Waits for (and consumes) one byte, then raises GateInterceptException.
    """
    await loop.sock_recv(socket_from, 1)
    raise GateInterceptException('Closing socket on data')
async def _intercept_corrupt(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Interceptor that forwards one chunk with its content destroyed.

    The length is preserved, but every zero byte becomes 0x01 and every
    non-zero byte becomes 0x00.
    """
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, bytearray([not x for x in data]))
157class _InterceptBpsLimit:
158 def __init__(self, bytes_per_second: float):
159 assert bytes_per_second >= 1
160 self._bytes_per_second = bytes_per_second
161 self._time_last_added = 0.0
162 self._bytes_left = self._bytes_per_second
164 def _update_limit(self) -> None:
165 current_time = time.monotonic()
166 elapsed = current_time - self._time_last_added
167 bytes_addition = self._bytes_per_second * elapsed
168 if bytes_addition > 0:
169 self._bytes_left += bytes_addition
170 self._time_last_added = current_time
172 if self._bytes_left > self._bytes_per_second:
173 self._bytes_left = self._bytes_per_second
183 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
184 if bytes_to_recv > 0:
185 data = await loop.sock_recv(socket_from, bytes_to_recv)
186 self._bytes_left -= len(data)
188 await loop.sock_sendall(socket_to, data)
190 logger.info(
'Socket hits the bytes per second limit')
191 await asyncio.sleep(1.0 / self._bytes_per_second)
194class _InterceptTimeLimit:
195 def __init__(self, timeout: float, jitter: float):
196 self._sockets: typing.Dict[Socket, float] = {}
197 assert timeout >= 0.0
198 self._timeout = timeout
200 self._jitter = jitter
202 def raise_if_timed_out(self, socket_from: Socket) ->
None:
203 if socket_from
not in self._sockets:
204 jitter = self._jitter * random.random()
205 expire_at = time.monotonic() + self._timeout + jitter
206 self._sockets[socket_from] = expire_at
208 if self._sockets[socket_from] <= time.monotonic():
209 del self._sockets[socket_from]
210 raise GateInterceptException(
'Socket hits the time limit')
218 self.raise_if_timed_out(socket_from)
219 await _intercept_ok(loop, socket_from, socket_to)
222class _InterceptSmallerParts:
223 def __init__(self, max_size: int, sleep_per_packet: float):
225 self._max_size = max_size
226 self._sleep_per_packet = sleep_per_packet
234 incoming_size = _incoming_data_size(socket_from)
235 chunk_size = min(incoming_size, self._max_size)
236 data = await loop.sock_recv(socket_from, chunk_size)
237 await asyncio.sleep(self._sleep_per_packet)
238 await loop.sock_sendall(socket_to, data)
241class _InterceptConcatPackets:
242 def __init__(self, packet_size: int):
243 assert packet_size >= 0
244 self._packet_size = packet_size
245 self._expire_at: typing.Optional[float] =
None
253 if self._expire_at
is None:
254 self._expire_at = time.monotonic() + MAX_DELAY
256 if self._expire_at <= time.monotonic():
258 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
259 'seconds. Check the test logic, it should end with checking '
260 'that the data was sent and by calling TcpGate function '
261 'to_client_pass() to pass the remaining packets.',
265 incoming_size = _incoming_data_size(socket_from)
266 if incoming_size >= self._packet_size:
267 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
268 await loop.sock_sendall(socket_to, data)
269 self._expire_at =
None
272class _InterceptBytesLimit:
273 def __init__(self, bytes_limit: int, gate:
'BaseGate'):
274 assert bytes_limit >= 0
275 self._bytes_limit = bytes_limit
276 self._bytes_remain = self._bytes_limit
285 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
286 if self._bytes_remain <= len(data):
287 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
288 await self._gate.sockets_close()
289 self._bytes_remain = self._bytes_limit
290 raise GateInterceptException(
'Data transmission limit reached')
292 self._bytes_remain -= len(data)
293 await loop.sock_sendall(socket_to, data)
296class _InterceptSubstitute:
297 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
298 self._pattern = re.compile(pattern)
300 self._encoding = encoding
308 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
310 res = self._pattern.sub(self._repl, data.decode(self._encoding))
311 data = res.encode(self._encoding)
314 await loop.sock_sendall(socket_to, data)
317async def _cancel_and_join(task: typing.Optional[asyncio.Task]) ->
None:
318 if not task
or task.cancelled():
324 except asyncio.CancelledError:
326 except Exception
as exc:
327 logger.error(
'Exception in _cancel_and_join: %s', exc)
330def _make_socket_nonblocking(sock: Socket) ->
None:
331 sock.setblocking(
False)
332 if sock.type == socket.SOCK_STREAM:
333 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
334 fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)
class _UdpDemuxSocketMock:
    """
    Emulates a point-to-point connection over UDP socket
    with a non-blocking socket interface

    Datagrams from the peer are fed in with push() and read back through
    recv()/recvfrom()/fileno(), which operate on an internal datagram socket
    pair. send() goes straight to the peer through the shared UDP socket.
    """

    def __init__(self, sock: Socket, peer_address: Address):
        self._sock: Socket = sock
        self._peeraddr: Address = peer_address

        sockpair = socket.socketpair(type=socket.SOCK_DGRAM)
        self._demux_in: Socket = sockpair[0]
        self._demux_out: Socket = sockpair[1]
        _make_socket_nonblocking(self._demux_in)
        _make_socket_nonblocking(self._demux_out)
        self._is_active: bool = True

    # Read as a plain attribute by the UDP gate when matching clients, hence
    # a property.
    @property
    def peer_address(self):
        return self._peeraddr

    async def push(self, loop: EvLoop, data: bytes):
        """Queue `data` so that it can be read via recv()/recvfrom()."""
        return await loop.sock_sendall(self._demux_in, data)

    def is_active(self) -> bool:
        return self._is_active

    # NOTE(review): the headers of is_active(), close() and fileno() were
    # lost in this copy. ``close`` is assumed so the mock matches the
    # socket.socket interface it stands in for; confirm against callers.
    def close(self) -> None:
        """Mark the mock inactive and close the internal socket pair."""
        self._is_active = False
        self._demux_out.close()
        self._demux_in.close()

    def recvfrom(self, bufsize: int, flags: int = 0):
        return self._demux_out.recvfrom(bufsize, flags)

    def recv(self, bufsize: int, flags: int = 0):
        return self._demux_out.recv(bufsize, flags)

    def fileno(self) -> int:
        return self._demux_out.fileno()

    def send(self, data: bytes):
        return self._sock.sendto(data, self._peeraddr)
387 client: typing.Union[socket.socket, _UdpDemuxSocketMock],
388 server: socket.socket,
389 to_server_intercept: Interceptor,
390 to_client_intercept: Interceptor,
392 self._proxy_name = proxy_name
395 self._client = client
396 self._server = server
398 self._to_server_intercept: Interceptor = to_server_intercept
399 self._to_client_intercept: Interceptor = to_client_intercept
401 self._task_to_server = asyncio.create_task(
402 self._do_pipe_channels(to_server=
True),
404 self._task_to_client = asyncio.create_task(
405 self._do_pipe_channels(to_server=
False),
408 self._finished_channels = 0
410 async def _do_pipe_channels(self, *, to_server: bool) ->
None:
412 socket_from = self._client
413 socket_to = self._server
415 socket_from = self._server
416 socket_to = self._client
425 if not _incoming_data_size(socket_from):
430 interceptor = self._to_server_intercept
432 interceptor = self._to_client_intercept
434 await interceptor(self._loop, socket_from, socket_to)
436 except GateInterceptException
as exc:
437 logger.info(
'In "%s": %s', self._proxy_name, exc)
438 except socket.error
as exc:
439 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
441 self._finished_channels += 1
442 if self._finished_channels == 2:
445 logger.info(
'"%s" closes %s', self._proxy_name, self.info())
446 self._close_socket(self._client)
447 self._close_socket(self._server)
449 assert self._finished_channels == 1
451 self._task_to_client.cancel()
453 self._task_to_server.cancel()
455 def set_to_server_interceptor(self, interceptor: Interceptor) ->
None:
456 self._to_server_intercept = interceptor
458 def set_to_client_interceptor(self, interceptor: Interceptor) ->
None:
459 self._to_client_intercept = interceptor
461 def _close_socket(self, self_socket: Socket) ->
None:
462 assert self_socket
in {self._client, self._server}
465 except socket.error
as exc:
467 'Exception in "%s" on closing %s: %s',
469 'client' if self_socket == self._client
else 'server',
473 async def shutdown(self) -> None:
474 for task
in {self._task_to_client, self._task_to_server}:
475 await _cancel_and_join(task)
477 def is_active(self) -> bool:
478 return not self._task_to_client.done()
or not self._task_to_server.done()
480 def info(self) -> str:
481 if not self.is_active():
484 return f
'client fd={self._client.fileno()} <=> ' f
'server fd={self._server.fileno()}'
841class TcpGate(BaseGate):
843 Implements TCP chaos-proxy logic such as accepting incoming tcp client
844 connections. On each new connection new tcp client connects to server
845 (host_to_server, port_to_server).
847 @ingroup userver_testsuite
849 @see @ref scripts/docs/en/userver/chaos_testing.md
852 def __init__(self, route: GateRoute, loop: EvLoop) -> None:
853 self._connected_event = asyncio.Event()
854 BaseGate.__init__(self, route, loop)
856 def connections_count(self) -> int:
858 Returns maximal amount of connections going through the gate at
861 @warning Some of the connections could be closing, or could be opened
862 right before the function starts. Use with caution!
864 return len(self._sockets)
866 async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
868 Wait for at least `count` connections going through the gate.
870 @throws asyncio.TimeoutError exception if failed to get the
871 required amount of connections in time.
874 while self.connections_count() < count:
875 await self._connected_event.wait()
876 self._connected_event.clear()
879 deadline = time.monotonic() + timeout
880 while self.connections_count() < count:
881 time_left = deadline - time.monotonic()
882 await asyncio.wait_for(
883 self._connected_event.wait(),
886 self._connected_event.clear()
888 def _create_accepting_sockets(self) -> typing.List[Socket]:
889 res: typing.List[Socket] = []
890 for addr in socket.getaddrinfo(
891 self._route.host_for_client,
892 self._route.port_for_client,
893 type=socket.SOCK_STREAM,
895 sock = Socket(addr[0], addr[1])
896 _make_socket_nonblocking(sock)
897 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
901 f'Accepting connections on {sock.getsockname()}, ' f'fd={sock.fileno()}',
907 async def _connect_to_server(self):
908 addrs = await self._loop.getaddrinfo(
909 self._route.host_to_server,
910 self._route.port_to_server,
911 type=socket.SOCK_STREAM,
914 server = Socket(addr[0], addr[1])
916 _make_socket_nonblocking(server)
917 await self._loop.sock_connect(server, addr[4])
918 logging.trace('Connected to %s', addr[4])
920 except Exception as exc: # pylint: disable=broad-except
922 logging.warning('Could not connect to %s: %s', addr[4], exc)
924 async def _do_accept(self, accept_sock: Socket) -> None:
926 client, _ = await self._loop.sock_accept(accept_sock)
927 _make_socket_nonblocking(client)
929 server = await self._connect_to_server()
937 self._to_server_intercept,
938 self._to_client_intercept,
941 self._connected_event.set()
945 self._collect_garbage()
948class UdpGate(BaseGate):
950 Implements UDP chaos-proxy logic such as demuxing incoming datagrams
951 from different clients.
952 Separate connections to server are made for each new client.
954 @ingroup userver_testsuite
956 @see @ref scripts/docs/en/userver/chaos_testing.md
959 def __init__(self, route: GateRoute, loop: EvLoop):
960 self._clients: typing.Set[_UdpDemuxSocketMock] = set()
961 BaseGate.__init__(self, route, loop)
963 def is_connected(self) -> bool:
965 Returns True if there is active pair of sockets ready to transfer data
968 return len(self._sockets) > 0
970 def _create_accepting_sockets(self) -> typing.List[Socket]:
971 res: typing.List[Socket] = []
972 for addr in socket.getaddrinfo(
973 self._route.host_for_client,
974 self._route.port_for_client,
975 type=socket.SOCK_DGRAM,
977 sock = socket.socket(addr[0], addr[1])
978 _make_socket_nonblocking(sock)
980 logger.debug(f'Accepting connections on {sock.getsockname()}')
985 async def _connect_to_server(self):
986 addrs = await self._loop.getaddrinfo(
987 self._route.host_to_server,
988 self._route.port_to_server,
989 type=socket.SOCK_DGRAM,
992 server = Socket(addr[0], addr[1])
994 _make_socket_nonblocking(server)
995 await self._loop.sock_connect(server, addr[4])
996 logging.trace('Connected to %s', addr[4])
998 except Exception as exc: # pylint: disable=broad-except
999 logging.warning('Could not connect to %s: %s', addr[4], exc)
1001 def _collect_garbage(self) -> None:
1002 super()._collect_garbage()
1003 self._clients = {c for c in self._clients if c.is_active()}
1005 async def _do_accept(self, accept_sock: Socket):
1007 data, addr = await _get_message_task(accept_sock)
1009 client: typing.Optional[_UdpDemuxSocketMock] = None
1010 for known_clients in self._clients:
1011 if addr == known_clients.peer_address:
1012 client = known_clients
1016 server = await self._connect_to_server()
1021 client = _UdpDemuxSocketMock(accept_sock, addr)
1022 self._clients.add(client)
1030 self._to_server_intercept,
1031 self._to_client_intercept,
1035 await client.push(self._loop, data)
1036 self._collect_garbage()
1038 def to_server_concat_packets(self, packet_size: int) -> None:
1039 raise NotImplementedError('Udp packets cannot be concatenated')
1041 def to_client_concat_packets(self, packet_size: int) -> None:
1042 raise NotImplementedError('Udp packets cannot be concatenated')
1044 def to_server_smaller_parts(
1048 sleep_per_packet: float = 0,
1050 raise NotImplementedError('Udp packets cannot be split')
1052 def to_client_smaller_parts(
1056 sleep_per_packet: float = 0,
1058 raise NotImplementedError('Udp packets cannot be split')