38 Class that describes the route for TcpGate or UdpGate.
40 Use `port_for_client == 0` to bind to some unused port. In that case the
41 actual address could be retrieved via BaseGate.get_sockname_for_clients().
43 @ingroup userver_testsuite
49 host_for_client: str =
'127.0.0.1'
50 port_for_client: int = 0
# Logger shared by every gate and interceptor in this module.
logger = logging.getLogger(__name__)

# (host, port) pair used to identify a socket endpoint.
Address = typing.Tuple[str, int]
66Interceptor = typing.Callable[
67 [EvLoop, Socket, Socket],
68 typing.Coroutine[typing.Any, typing.Any,
None],
72class GateException(Exception):
76class GateInterceptException(Exception):
80async def _intercept_ok(
85 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
87 raise ConnectionClosedError()
88 await loop.sock_sendall(socket_to, data)
91async def _intercept_drop(
96 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
98 raise ConnectionClosedError()
101async def _intercept_delay(
107 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
109 raise ConnectionClosedError()
110 await asyncio.sleep(delay)
111 await loop.sock_sendall(socket_to, data)
114async def _intercept_close_on_data(
119 data = await loop.sock_recv(socket_from, 1)
121 raise ConnectionClosedError()
122 raise GateInterceptException(
'Closing socket on data')
125async def _intercept_corrupt(
130 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
132 raise ConnectionClosedError()
133 await loop.sock_sendall(socket_to, bytearray([
not x
for x
in data]))
136class _InterceptBpsLimit:
def __init__(self, bytes_per_second: float):
    """Set up a byte budget that refills at ``bytes_per_second``.

    The budget starts full. ``_time_last_added`` of 0.0 means the budget has
    never been refilled yet. The first refill therefore sees a large elapsed
    time, and the result is capped back to ``bytes_per_second``.
    """
    assert bytes_per_second >= 1
    self._bytes_per_second = bytes_per_second
    # Start with a full budget.
    self._bytes_left = bytes_per_second
    self._time_last_added = 0.0
def _update_limit(self) -> None:
    """Refill the byte budget for the time elapsed since the last refill.

    The refill is ``bytes_per_second * elapsed_seconds``. The timestamp only
    advances when something was actually added. The resulting budget never
    exceeds ``bytes_per_second`` (one second's worth of traffic).
    """
    now = time.monotonic()
    refill = self._bytes_per_second * (now - self._time_last_added)
    if refill > 0:
        self._bytes_left += refill
        self._time_last_added = now
    # Cap the budget so idle periods cannot accumulate an unbounded burst.
    self._bytes_left = min(self._bytes_left, self._bytes_per_second)
162 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
163 if bytes_to_recv > 0:
164 data = await loop.sock_recv(socket_from, bytes_to_recv)
166 raise ConnectionClosedError()
167 self._bytes_left -= len(data)
169 await loop.sock_sendall(socket_to, data)
171 logger.info(
'Socket hits the bytes per second limit')
172 await asyncio.sleep(1.0 / self._bytes_per_second)
175class _InterceptTimeLimit:
176 def __init__(self, timeout: float, jitter: float):
177 self._sockets: typing.Dict[Socket, float] = {}
178 assert timeout >= 0.0
179 self._timeout = timeout
181 self._jitter = jitter
def raise_if_timed_out(self, socket_from: Socket) -> None:
    """Raise once ``socket_from`` has outlived its deadline.

    On the first call for a socket, a deadline of ``timeout`` plus a random
    share of ``jitter`` seconds from now is recorded. When a later call finds
    the deadline passed, the entry is dropped and GateInterceptException is
    raised. A subsequent call for the same socket then starts a new deadline.
    """
    deadline = self._sockets.get(socket_from)
    if deadline is None:
        jitter = self._jitter * random.random()
        deadline = time.monotonic() + self._timeout + jitter
        self._sockets[socket_from] = deadline

    if deadline <= time.monotonic():
        del self._sockets[socket_from]
        raise GateInterceptException('Socket hits the time limit')
199 self.raise_if_timed_out(socket_from)
200 await _intercept_ok(loop, socket_from, socket_to)
203class _InterceptSmallerParts:
204 def __init__(self, max_size: int, sleep_per_packet: float):
206 self._max_size = max_size
207 self._sleep_per_packet = sleep_per_packet
215 data = await loop.sock_recv(socket_from, self._max_size)
217 raise ConnectionClosedError()
218 await asyncio.sleep(self._sleep_per_packet)
219 await loop.sock_sendall(socket_to, data)
222class _InterceptConcatPackets:
def __init__(self, packet_size: int):
    """Prepare to buffer forwarded bytes until ``packet_size`` are collected.

    ``_expire_at`` is ``None`` while no packet is being assembled.
    """
    assert packet_size >= 0
    self._buf = io.BytesIO()
    self._expire_at: typing.Optional[float] = None
    self._packet_size = packet_size
235 if self._expire_at
is None:
236 self._expire_at = time.monotonic() + MAX_DELAY
238 if self._expire_at <= time.monotonic():
240 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
241 'seconds. Check the test logic, it should end with checking '
242 'that the data was sent and by calling TcpGate function '
243 'to_client_pass() to pass the remaining packets.',
246 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
248 raise ConnectionClosedError()
249 self._buf.write(data)
250 if self._buf.tell() >= self._packet_size:
251 await loop.sock_sendall(socket_to, self._buf.getvalue())
252 self._buf = io.BytesIO()
253 self._expire_at =
None
256class _InterceptBytesLimit:
257 def __init__(self, bytes_limit: int, gate:
'BaseGate'):
258 assert bytes_limit >= 0
259 self._bytes_limit = bytes_limit
260 self._bytes_remain = self._bytes_limit
269 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
271 raise ConnectionClosedError()
272 if self._bytes_remain <= len(data):
273 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
274 await self._gate.sockets_close()
275 self._bytes_remain = self._bytes_limit
276 raise GateInterceptException(
'Data transmission limit reached')
277 self._bytes_remain -= len(data)
278 await loop.sock_sendall(socket_to, data)
281class _InterceptSubstitute:
282 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
283 self._pattern = re.compile(pattern)
285 self._encoding = encoding
293 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
295 raise ConnectionClosedError()
297 res = self._pattern.sub(self._repl, data.decode(self._encoding))
298 data = res.encode(self._encoding)
301 await loop.sock_sendall(socket_to, data)
304async def _cancel_and_join(task: typing.Optional[asyncio.Task]) ->
None:
305 if not task
or task.cancelled():
311 except asyncio.CancelledError:
314 logger.exception(
'Exception in _cancel_and_join')
def _make_socket_nonblocking(sock: Socket) -> None:
    """Switch ``sock`` to non-blocking mode.

    For TCP sockets, Nagle's algorithm is also disabled (``TCP_NODELAY``) so
    small writes are not delayed.

    ``TCP_NODELAY`` is an IP-level TCP option. Setting it on a stream socket
    of another family, e.g. an ``AF_UNIX`` pair from ``socket.socketpair()``,
    fails with ``OSError``. Only ``AF_INET``/``AF_INET6`` stream sockets get it.
    """
    sock.setblocking(False)
    is_tcp = sock.type == socket.SOCK_STREAM and sock.family in (
        socket.AF_INET,
        socket.AF_INET6,
    )
    if is_tcp:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
323class _UdpDemuxSocketMock:
325 Emulates a point-to-point connection over UDP socket
326 with a non-blocking socket interface
def gettimeout(self):
    """Return the timeout of the wrapped UDP socket."""
    sock = self._sock
    return sock.gettimeout()

def __init__(self, sock: Socket, peer_address: Address):
    """Wrap ``sock`` as a per-client channel for ``peer_address``.

    Outgoing data goes straight to ``peer_address`` through ``sock``.
    Incoming data is injected with ``push`` into an internal datagram
    socket pair and read back from its other end.
    """
    self._sock: Socket = sock
    self._peeraddr: Address = peer_address

    # `push` writes into `_demux_in`. recv/recvfrom/fileno all use
    # `_demux_out`.
    demux_in, demux_out = socket.socketpair(type=socket.SOCK_DGRAM)
    self._demux_in: Socket = demux_in
    self._demux_out: Socket = demux_out
    for end in (demux_in, demux_out):
        _make_socket_nonblocking(end)
    self._is_active: bool = True

def peer_address(self):
    """Return the client address this channel was created for."""
    # NOTE(review): callers read `peer_address` as an attribute; the
    # decorator enabling that is not visible in this excerpt — confirm.
    return self._peeraddr

async def push(self, loop: EvLoop, data: bytes):
    """Queue ``data`` into the internal socket pair for readers of this mock."""
    demux_in = self._demux_in
    return await loop.sock_sendall(demux_in, data)
351 return self._is_active
354 self._is_active =
False
355 self._demux_out.close()
356 self._demux_in.close()
def recvfrom(self, bufsize: int, flags: int = 0):
    """Read one demultiplexed datagram and its sender from the demux socket.

    The sender reported is the internal socket pair's end, not the client.
    """
    out = self._demux_out
    return out.recvfrom(bufsize, flags)

def recv(self, bufsize: int, flags: int = 0):
    """Read one demultiplexed datagram from the demux socket."""
    out = self._demux_out
    return out.recv(bufsize, flags)

def get_demux_out(self):
    """Expose the readable end of the internal socket pair."""
    out = self._demux_out
    return out
368 return self._demux_out.fileno()
def send(self, data: bytes):
    """Send ``data`` to the client through the wrapped UDP socket."""
    target = self._peeraddr
    return self._sock.sendto(data, target)
def __init__(self, socket_from, socket_to, interceptor):
    """Remember the transfer direction and its initial interceptor."""
    self._interceptor = interceptor
    self._socket_from = socket_from
    self._socket_to = socket_to
    # Guards `_interceptor`. It also lets a runner waiting for an
    # interceptor be woken up when one is installed.
    self._condition = asyncio.Condition()

def get_interceptor(self):
    """Return the interceptor currently in effect."""
    return self._interceptor

async def set_interceptor(self, interceptor):
    """Install ``interceptor`` and notify one waiter on the condition."""
    condition = self._condition
    await condition.acquire()
    try:
        self._interceptor = interceptor
        condition.notify()
    finally:
        condition.release()
390 loop = asyncio.get_running_loop()
397 await _wait_for_data(self._socket_from)
400 async with self._condition:
401 interceptor = await self._condition.wait_for(self.get_interceptor)
403 logging.trace(
'running interceptor: %s', interceptor)
404 await interceptor(loop, self._socket_from, self._socket_to)
412 client: typing.Union[socket.socket, _UdpDemuxSocketMock],
413 server: socket.socket,
414 to_server_intercept: Interceptor,
415 to_client_intercept: Interceptor,
417 self._proxy_name = proxy_name
419 self._client = client
420 self._server = server
422 self._task_to_server = InterceptTask(client, server, to_server_intercept)
423 self._task_to_client = InterceptTask(server, client, to_client_intercept)
425 self._task = asyncio.create_task(self._run())
426 self._interceptor_tasks = []
async def set_to_server_interceptor(self, interceptor):
    """Replace the interceptor used for client-to-server traffic."""
    task = self._task_to_server
    await task.set_interceptor(interceptor)

async def set_to_client_interceptor(self, interceptor: Interceptor):
    """Replace the interceptor used for server-to-client traffic."""
    task = self._task_to_client
    await task.set_interceptor(interceptor)

async def shutdown(self) -> None:
    """Cancel and await every task in `_interceptor_tasks`, then the main task."""
    for interceptor_task in self._interceptor_tasks:
        await _cancel_and_join(interceptor_task)
    await _cancel_and_join(self._task)

def is_active(self) -> bool:
    """Return True until the pair's main task has finished."""
    finished = self._task.done()
    return not finished
442 def info(self) -> str:
443 if not self.is_active():
446 return f
'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'
448 async def _run(self):
449 self._interceptor_tasks = [
450 asyncio.create_task(obj.run())
for obj
in (self._task_to_server, self._task_to_client)
453 done, _ = await asyncio.wait(self._interceptor_tasks, return_when=asyncio.FIRST_EXCEPTION)
456 except GateInterceptException
as exc:
457 logger.info(
'In "%s": %s', self._proxy_name, exc)
458 except socket.error
as exc:
459 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
461 logger.exception(
'interceptor failed')
463 for task
in self._interceptor_tasks:
468 for sock
in self._server, self._client:
472 logger.exception(
'Exception in "%s" on closing %s:', self._proxy_name, sock)
837class TcpGate(BaseGate):
839 Implements TCP chaos-proxy logic such as accepting incoming tcp client
840 connections. On each new connection new tcp client connects to server
841 (host_to_server, port_to_server).
843 @ingroup userver_testsuite
845 @see @ref scripts/docs/en/userver/chaos_testing.md
def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
    """Create a TCP gate for ``route``, running on ``loop`` if one is given."""
    # Created before the base initializer runs; it is set whenever a new
    # connection pair is registered. NOTE(review): presumably the base
    # initializer may already need it — keep this ordering.
    connected_event = asyncio.Event()
    self._connected_event = connected_event
    super().__init__(route, loop)
852 def connections_count(self) -> int:
854 Returns maximal amount of connections going through the gate at
857 @warning Some of the connections could be closing, or could be opened
858 right before the function starts. Use with caution!
860 return len(self._sockets)
852 def connections_count(self) -> int:
…
862 async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
864 Wait for at least `count` connections going through the gate.
866 @throws asyncio.TimeoutError exception if failed to get the
867 required amount of connections in time.
870 while self.connections_count() < count:
871 await self._connected_event.wait()
872 self._connected_event.clear()
875 deadline = time.monotonic() + timeout
876 while self.connections_count() < count:
877 time_left = deadline - time.monotonic()
878 await asyncio.wait_for(
879 self._connected_event.wait(),
882 self._connected_event.clear()
862 async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
…
884 def _create_accepting_sockets(self) -> typing.List[Socket]:
885 res: typing.List[Socket] = []
886 for addr in socket.getaddrinfo(
887 self._route.host_for_client,
888 self._route.port_for_client,
889 type=socket.SOCK_STREAM,
891 sock = Socket(addr[0], addr[1])
892 _make_socket_nonblocking(sock)
893 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
897 f'Accepting connections on {sock.getsockname()}, fd={sock.fileno()}',
903 async def _connect_to_server(self):
904 addrs = await self._loop.getaddrinfo(
905 self._route.host_to_server,
906 self._route.port_to_server,
907 type=socket.SOCK_STREAM,
910 server = Socket(addr[0], addr[1])
911 _make_socket_nonblocking(server)
913 await self._loop.sock_connect(server, addr[4])
914 logging.trace('Connected to %s', addr[4])
916 except Exception as exc: # pylint: disable=broad-except
918 logging.warning('Could not connect to %s: %s', addr[4], exc)
920 async def _do_accept(self, accept_sock: Socket) -> None:
922 client, _ = await self._loop.sock_accept(accept_sock)
923 _make_socket_nonblocking(client)
925 server = await self._connect_to_server()
933 self._to_server_intercept,
934 self._to_client_intercept,
937 self._connected_event.set()
941 self._collect_garbage()
920 async def _do_accept(self, accept_sock: Socket) -> None:
…
944class UdpGate(BaseGate):
946 Implements UDP chaos-proxy logic such as demuxing incoming datagrams
947 from different clients.
948 Separate connections to server are made for each new client.
950 @ingroup userver_testsuite
952 @see @ref scripts/docs/en/userver/chaos_testing.md
def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None):
    """Create a UDP gate for ``route``, running on ``loop`` if one is given."""
    # Per-client demux channels. Matched against the datagram sender's
    # address when traffic arrives.
    known_clients: typing.Set[_UdpDemuxSocketMock] = set()
    self._clients = known_clients
    super().__init__(route, loop)
959 def is_connected(self) -> bool:
961 Returns True if there is active pair of sockets ready to transfer data
964 return len(self._sockets) > 0
959 def is_connected(self) -> bool:
…
966 def _create_accepting_sockets(self) -> typing.List[Socket]:
967 res: typing.List[Socket] = []
968 for addr in socket.getaddrinfo(
969 self._route.host_for_client,
970 self._route.port_for_client,
971 type=socket.SOCK_DGRAM,
973 sock = socket.socket(addr[0], addr[1])
974 _make_socket_nonblocking(sock)
976 logger.debug(f'Accepting connections on {sock.getsockname()}')
981 async def _connect_to_server(self):
982 addrs = await self._loop.getaddrinfo(
983 self._route.host_to_server,
984 self._route.port_to_server,
985 type=socket.SOCK_DGRAM,
988 server = Socket(addr[0], addr[1])
990 _make_socket_nonblocking(server)
991 await self._loop.sock_connect(server, addr[4])
992 logging.trace('Connected to %s', addr[4])
994 except Exception as exc: # pylint: disable=broad-except
995 logging.warning('Could not connect to %s: %s', addr[4], exc)
def _collect_garbage(self) -> None:
    """Run the base cleanup, then forget clients whose ``is_active()`` is false."""
    super()._collect_garbage()
    self._clients = set(filter(lambda client: client.is_active(), self._clients))
1001 async def _do_accept(self, accept_sock: Socket):
1002 sock = asyncio_socket.from_socket(accept_sock)
1004 data, addr = await sock.recvfrom(RECV_MAX_SIZE, timeout=60.0)
1006 client: typing.Optional[_UdpDemuxSocketMock] = None
1007 for known_clients in self._clients:
1008 if addr == known_clients.peer_address:
1009 client = known_clients
1013 server = await self._connect_to_server()
1018 client = _UdpDemuxSocketMock(accept_sock, addr)
1019 self._clients.add(client)
1027 self._to_server_intercept,
1028 self._to_client_intercept,
1032 await client.push(self._loop, data)
1033 self._collect_garbage()
1001 async def _do_accept(self, accept_sock: Socket):
…
async def to_server_concat_packets(self, packet_size: int) -> None:
    """Not supported for UDP: datagrams are forwarded one by one."""
    message = 'Udp packets cannot be concatenated'
    raise NotImplementedError(message)

async def to_client_concat_packets(self, packet_size: int) -> None:
    """Not supported for UDP: datagrams are forwarded one by one."""
    message = 'Udp packets cannot be concatenated'
    raise NotImplementedError(message)
1041 async def to_server_smaller_parts(
1045 sleep_per_packet: float = 0,
1047 raise NotImplementedError('Udp packets cannot be split')
1041 async def to_server_smaller_parts(
…
1049 async def to_client_smaller_parts(
1053 sleep_per_packet: float = 0,
1055 raise NotImplementedError('Udp packets cannot be split')
1049 async def to_client_smaller_parts(
…