39 Class that describes the route for TcpGate or UdpGate.
41 Use `port_for_client == 0` to bind to some unused port. In that case the
42 actual address could be retrieved via BaseGate.get_sockname_for_clients().
44 @ingroup userver_testsuite
50 host_for_client: str =
'127.0.0.1'
51 port_for_client: int = 0
61logger = logging.getLogger(__name__)
64Address = typing.Tuple[str, int]
67Interceptor = typing.Callable[
68 [EvLoop, Socket, Socket],
69 typing.Coroutine[typing.Any, typing.Any,
None],
73class GateException(Exception):
77class GateInterceptException(Exception):
81async def _intercept_ok(
86 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
88 raise ConnectionClosedError()
89 await loop.sock_sendall(socket_to, data)
92async def _intercept_drop(
97 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
99 raise ConnectionClosedError()
102async def _intercept_delay(
108 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
110 raise ConnectionClosedError()
111 await asyncio.sleep(delay)
112 await loop.sock_sendall(socket_to, data)
115async def _intercept_close_on_data(
120 data = await loop.sock_recv(socket_from, 1)
122 raise ConnectionClosedError()
123 raise GateInterceptException(
'Closing socket on data')
126async def _intercept_corrupt(
131 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
133 raise ConnectionClosedError()
134 await loop.sock_sendall(socket_to, bytearray([
not x
for x
in data]))
137class _InterceptBpsLimit:
138 def __init__(self, bytes_per_second: float):
139 assert bytes_per_second >= 1
140 self._bytes_per_second = bytes_per_second
141 self._time_last_added = 0.0
142 self._bytes_left = self._bytes_per_second
144 def _update_limit(self) -> None:
145 current_time = time.monotonic()
146 elapsed = current_time - self._time_last_added
147 bytes_addition = self._bytes_per_second * elapsed
148 if bytes_addition > 0:
149 self._bytes_left += bytes_addition
150 self._time_last_added = current_time
152 if self._bytes_left > self._bytes_per_second:
153 self._bytes_left = self._bytes_per_second
163 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
164 if bytes_to_recv > 0:
165 data = await loop.sock_recv(socket_from, bytes_to_recv)
167 raise ConnectionClosedError()
168 self._bytes_left -= len(data)
170 await loop.sock_sendall(socket_to, data)
172 logger.info(
'Socket hits the bytes per second limit')
173 await asyncio.sleep(1.0 / self._bytes_per_second)
176class _InterceptTimeLimit:
177 def __init__(self, timeout: float, jitter: float):
178 self._sockets: typing.Dict[Socket, float] = {}
179 assert timeout >= 0.0
180 self._timeout = timeout
182 self._jitter = jitter
184 def raise_if_timed_out(self, socket_from: Socket) ->
None:
185 if socket_from
not in self._sockets:
186 jitter = self._jitter * random.random()
187 expire_at = time.monotonic() + self._timeout + jitter
188 self._sockets[socket_from] = expire_at
190 if self._sockets[socket_from] <= time.monotonic():
191 del self._sockets[socket_from]
192 raise GateInterceptException(
'Socket hits the time limit')
200 self.raise_if_timed_out(socket_from)
201 await _intercept_ok(loop, socket_from, socket_to)
204class _InterceptSmallerParts:
205 def __init__(self, max_size: int, sleep_per_packet: float):
207 self._max_size = max_size
208 self._sleep_per_packet = sleep_per_packet
216 data = await loop.sock_recv(socket_from, self._max_size)
218 raise ConnectionClosedError()
219 await asyncio.sleep(self._sleep_per_packet)
220 await loop.sock_sendall(socket_to, data)
223class _InterceptConcatPackets:
224 def __init__(self, packet_size: int):
225 assert packet_size >= 0
226 self._packet_size = packet_size
227 self._expire_at: typing.Optional[float] =
None
228 self._buf = io.BytesIO()
236 if self._expire_at
is None:
237 self._expire_at = time.monotonic() + MAX_DELAY
239 if self._expire_at <= time.monotonic():
241 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
242 'seconds. Check the test logic, it should end with checking '
243 'that the data was sent and by calling TcpGate function '
244 'to_client_pass() to pass the remaining packets.',
247 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
249 raise ConnectionClosedError()
250 self._buf.write(data)
251 if self._buf.tell() >= self._packet_size:
252 await loop.sock_sendall(socket_to, self._buf.getvalue())
253 self._buf = io.BytesIO()
254 self._expire_at =
None
257class _InterceptBytesLimit:
258 def __init__(self, bytes_limit: int, gate:
'BaseGate'):
259 assert bytes_limit >= 0
260 self._bytes_limit = bytes_limit
261 self._bytes_remain = self._bytes_limit
270 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
272 raise ConnectionClosedError()
273 if self._bytes_remain <= len(data):
274 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
275 await self._gate.sockets_close()
276 self._bytes_remain = self._bytes_limit
277 raise GateInterceptException(
'Data transmission limit reached')
278 self._bytes_remain -= len(data)
279 await loop.sock_sendall(socket_to, data)
282class _InterceptSubstitute:
283 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
284 self._pattern = re.compile(pattern)
286 self._encoding = encoding
294 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
296 raise ConnectionClosedError()
298 res = self._pattern.sub(self._repl, data.decode(self._encoding))
299 data = res.encode(self._encoding)
302 await loop.sock_sendall(socket_to, data)
305async def _cancel_and_join(task: typing.Optional[asyncio.Task]) ->
None:
306 if not task
or task.cancelled():
312 except asyncio.CancelledError:
315 logger.exception(
'Exception in _cancel_and_join')
318def _make_socket_nonblocking(sock: Socket) ->
None:
319 sock.setblocking(
False)
320 if sock.type == socket.SOCK_STREAM:
321 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
324class _UdpDemuxSocketMock:
326 Emulates a point-to-point connection over UDP socket
327 with a non-blocking socket interface
330 def gettimeout(self):
331 return self._sock.gettimeout()
333 def __init__(self, sock: Socket, peer_address: Address):
334 self._sock: Socket = sock
335 self._peeraddr: Address = peer_address
337 sockpair = socket.socketpair(type=socket.SOCK_DGRAM)
338 self._demux_in: Socket = sockpair[0]
339 self._demux_out: Socket = sockpair[1]
340 _make_socket_nonblocking(self._demux_in)
341 _make_socket_nonblocking(self._demux_out)
342 self._is_active: bool =
True
345 def peer_address(self):
346 return self._peeraddr
348 async def push(self, loop: EvLoop, data: bytes):
349 return await loop.sock_sendall(self._demux_in, data)
352 return self._is_active
355 self._is_active =
False
356 self._demux_out.close()
357 self._demux_in.close()
359 def recvfrom(self, bufsize: int, flags: int = 0):
360 return self._demux_out.recvfrom(bufsize, flags)
362 def recv(self, bufsize: int, flags: int = 0):
363 return self._demux_out.recv(bufsize, flags)
365 def get_demux_out(self):
366 return self._demux_out
369 return self._demux_out.fileno()
371 def send(self, data: bytes):
372 return self._sock.sendto(data, self._peeraddr)
376 def __init__(self, socket_from, socket_to, interceptor):
377 self._socket_from = socket_from
378 self._socket_to = socket_to
379 self._condition = asyncio.Condition()
380 self._interceptor = interceptor
382 def get_interceptor(self):
383 return self._interceptor
385 async def set_interceptor(self, interceptor):
386 async with self._condition:
387 self._interceptor = interceptor
388 self._condition.notify()
391 loop = asyncio.get_running_loop()
398 await _wait_for_data(self._socket_from)
401 async with self._condition:
402 interceptor = await self._condition.wait_for(self.get_interceptor)
404 logging.trace(
'running interceptor: %s', interceptor)
405 await interceptor(loop, self._socket_from, self._socket_to)
413 client: typing.Union[socket.socket, _UdpDemuxSocketMock],
414 server: socket.socket,
415 to_server_intercept: Interceptor,
416 to_client_intercept: Interceptor,
418 self._proxy_name = proxy_name
420 self._client = client
421 self._server = server
423 self._task_to_server = InterceptTask(client, server, to_server_intercept)
424 self._task_to_client = InterceptTask(server, client, to_client_intercept)
426 self._task = asyncio.create_task(self._run())
427 self._interceptor_tasks = []
429 async def set_to_server_interceptor(self, interceptor):
430 await self._task_to_server.set_interceptor(interceptor)
432 async def set_to_client_interceptor(self, interceptor: Interceptor):
433 await self._task_to_client.set_interceptor(interceptor)
435 async def shutdown(self) -> None:
436 for task
in self._interceptor_tasks:
437 await _cancel_and_join(task)
438 await _cancel_and_join(self._task)
440 def is_active(self) -> bool:
441 return not self._task.done()
443 def info(self) -> str:
444 if not self.is_active():
447 return f
'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'
449 async def _run(self):
450 self._interceptor_tasks = [
451 asyncio.create_task(obj.run())
for obj
in (self._task_to_server, self._task_to_client)
454 done, _ = await asyncio.wait(self._interceptor_tasks, return_when=asyncio.FIRST_EXCEPTION)
457 except GateInterceptException
as exc:
458 logger.info(
'In "%s": %s', self._proxy_name, exc)
459 except socket.error
as exc:
460 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
462 logger.exception(
'interceptor failed')
464 for task
in self._interceptor_tasks:
469 for sock
in self._server, self._client:
473 logger.exception(
'Exception in "%s" on closing %s:', self._proxy_name, sock)
838class TcpGate(BaseGate):
840 Implements TCP chaos-proxy logic such as accepting incoming tcp client
841 connections. On each new connection new tcp client connects to server
842 (host_to_server, port_to_server).
844 @ingroup userver_testsuite
846 @see @ref scripts/docs/en/userver/chaos_testing.md
849 def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
850 self._connected_event = asyncio.Event()
851 super().__init__(route, loop)
853 def connections_count(self) -> int:
855 Returns maximal amount of connections going through the gate at
858 @warning Some of the connections could be closing, or could be opened
859 right before the function starts. Use with caution!
861 return len(self._sockets)
863 async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
865 Wait for at least `count` connections going through the gate.
867 @throws asyncio.TimeoutError exception if failed to get the
868 required amount of connections in time.
871 while self.connections_count() < count:
872 await self._connected_event.wait()
873 self._connected_event.clear()
876 deadline = time.monotonic() + timeout
877 while self.connections_count() < count:
878 time_left = deadline - time.monotonic()
879 await asyncio.wait_for(
880 self._connected_event.wait(),
883 self._connected_event.clear()
885 def _create_accepting_sockets(self) -> typing.List[Socket]:
886 res: typing.List[Socket] = []
887 for addr in socket.getaddrinfo(
888 self._route.host_for_client,
889 self._route.port_for_client,
890 type=socket.SOCK_STREAM,
892 sock = Socket(addr[0], addr[1])
893 _make_socket_nonblocking(sock)
894 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
898 f'Accepting connections on {sock.getsockname()}, fd={sock.fileno()}',
904 async def _connect_to_server(self):
905 addrs = await self._loop.getaddrinfo(
906 self._route.host_to_server,
907 self._route.port_to_server,
908 type=socket.SOCK_STREAM,
911 server = Socket(addr[0], addr[1])
912 _make_socket_nonblocking(server)
914 await self._loop.sock_connect(server, addr[4])
915 logging.trace('Connected to %s', addr[4])
917 except Exception as exc: # pylint: disable=broad-except
919 logging.warning('Could not connect to %s: %s', addr[4], exc)
921 async def _do_accept(self, accept_sock: Socket) -> None:
923 client, _ = await self._loop.sock_accept(accept_sock)
924 _make_socket_nonblocking(client)
926 server = await self._connect_to_server()
934 self._to_server_intercept,
935 self._to_client_intercept,
938 self._connected_event.set()
942 self._collect_garbage()
945class UdpGate(BaseGate):
947 Implements UDP chaos-proxy logic such as demuxing incoming datagrams
948 from different clients.
949 Separate connections to server are made for each new client.
951 @ingroup userver_testsuite
953 @see @ref scripts/docs/en/userver/chaos_testing.md
956 def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None):
957 self._clients: typing.Set[_UdpDemuxSocketMock] = set()
958 super().__init__(route, loop)
960 def is_connected(self) -> bool:
962 Returns True if there is active pair of sockets ready to transfer data
965 return len(self._sockets) > 0
967 def _create_accepting_sockets(self) -> typing.List[Socket]:
968 res: typing.List[Socket] = []
969 for addr in socket.getaddrinfo(
970 self._route.host_for_client,
971 self._route.port_for_client,
972 type=socket.SOCK_DGRAM,
974 sock = socket.socket(addr[0], addr[1])
975 _make_socket_nonblocking(sock)
977 logger.debug(f'Accepting connections on {sock.getsockname()}')
982 async def _connect_to_server(self):
983 addrs = await self._loop.getaddrinfo(
984 self._route.host_to_server,
985 self._route.port_to_server,
986 type=socket.SOCK_DGRAM,
989 server = Socket(addr[0], addr[1])
991 _make_socket_nonblocking(server)
992 await self._loop.sock_connect(server, addr[4])
993 logging.trace('Connected to %s', addr[4])
995 except Exception as exc: # pylint: disable=broad-except
996 logging.warning('Could not connect to %s: %s', addr[4], exc)
998 def _collect_garbage(self) -> None:
999 super()._collect_garbage()
1000 self._clients = {c for c in self._clients if c.is_active()}
1002 async def _do_accept(self, accept_sock: Socket):
1003 sock = asyncio_socket.from_socket(accept_sock)
1005 data, addr = await sock.recvfrom(RECV_MAX_SIZE, timeout=60.0)
1007 client: typing.Optional[_UdpDemuxSocketMock] = None
1008 for known_clients in self._clients:
1009 if addr == known_clients.peer_address:
1010 client = known_clients
1014 server = await self._connect_to_server()
1019 client = _UdpDemuxSocketMock(accept_sock, addr)
1020 self._clients.add(client)
1028 self._to_server_intercept,
1029 self._to_client_intercept,
1033 await client.push(self._loop, data)
1034 self._collect_garbage()
1036 async def to_server_concat_packets(self, packet_size: int) -> None:
1037 raise NotImplementedError('Udp packets cannot be concatenated')
1039 async def to_client_concat_packets(self, packet_size: int) -> None:
1040 raise NotImplementedError('Udp packets cannot be concatenated')
1042 async def to_server_smaller_parts(
1046 sleep_per_packet: float = 0,
1048 raise NotImplementedError('Udp packets cannot be split')
1050 async def to_client_smaller_parts(
1054 sleep_per_packet: float = 0,
1056 raise NotImplementedError('Udp packets cannot be split')