userver: testsuite/pytest_plugins/pytest_userver/chaos.py
Loading...
Searching...
No Matches
testsuite/pytest_plugins/pytest_userver/chaos.py
1# pylint: disable=too-many-lines
2"""
3Python module that provides testsuite support for
4chaos tests; see
5@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
6
7@ingroup userver_testsuite
8"""
9
10from __future__ import annotations
11
12import asyncio
13from collections.abc import Callable
14from collections.abc import Coroutine
15import dataclasses
16import functools
17import io
18import logging
19import random
20import re
21import socket
22import time
23from typing import Any
24from typing import TypeAlias
25
26import pytest
27
28from testsuite import asyncio_socket
29from testsuite.utils import callinfo
30
31
class BaseError(Exception):
    """Base class of the public chaos-testing errors (see ConnectionClosedError)."""
34
35
class ConnectionClosedError(BaseError):
    """Raised by interceptors when reading from the source socket returns no data (EOF)."""
38
39
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Class that describes the route for TcpGate or UdpGate.

    Clients connect to (host_for_client, port_for_client); the gate forwards
    their traffic to (host_to_server, port_to_server).

    Use `port_for_client == 0` to bind to some unused port. In that case the
    actual address could be retrieved via BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Human-readable route name, used in log messages of the gate.
    name: str
    # Address the gate forwards the client traffic to.
    host_to_server: str
    port_to_server: int
    # Address the gate listens on for clients; port 0 picks a free port.
    host_for_client: str = dataclasses.field(default='127.0.0.1')
    port_for_client: int = dataclasses.field(default=0)
56
57
58# @cond
59
60# https://docs.python.org/3/library/socket.html#socket.socket.recv
61RECV_MAX_SIZE = 4096
62MAX_DELAY = 60.0
63
64
65logger = logging.getLogger(__name__)
66
67
68Address: TypeAlias = tuple[str, int]
69EvLoop: TypeAlias = Any
70Socket: TypeAlias = socket.socket
71Interceptor: TypeAlias = Callable[[EvLoop, Socket, Socket], Coroutine[Any, Any, None]]
72
73
class GateException(Exception):
    """Raised by a gate on setup errors, e.g. when the client host cannot be resolved."""
76
77
class GateInterceptException(Exception):
    """Raised by an interceptor to deliberately terminate a connection pair."""
80
81
async def _intercept_ok(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Forward one chunk of at most RECV_MAX_SIZE bytes unchanged.

    Raises ConnectionClosedError when `socket_from` yields no data.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if chunk:
        await loop.sock_sendall(socket_to, chunk)
        return
    raise ConnectionClosedError()
91
92
async def _intercept_drop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk from `socket_from` and discard it; `socket_to` is unused.

    Raises ConnectionClosedError when `socket_from` yields no data.
    """
    discarded = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if discarded:
        return
    raise ConnectionClosedError()
101
102
async def _intercept_delay(
    delay: float,
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk, sleep `delay` seconds, then forward the chunk unchanged.

    Raises ConnectionClosedError when `socket_from` yields no data.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if not chunk:
        raise ConnectionClosedError()
    # The delay is applied after the data has been read, per chunk.
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, chunk)
114
115
async def _intercept_close_on_data(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Terminate the connection as soon as a single byte arrives.

    Raises ConnectionClosedError on EOF, otherwise GateInterceptException.
    Nothing is forwarded to `socket_to`.
    """
    first_byte = await loop.sock_recv(socket_from, 1)
    if first_byte:
        raise GateInterceptException('Closing socket on data')
    raise ConnectionClosedError()
125
126
async def _intercept_corrupt(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk and forward a corrupted copy of the same length.

    Every byte is replaced with 1 if it was zero and with 0 otherwise, so no
    original byte value survives. Raises ConnectionClosedError on EOF.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    if not chunk:
        raise ConnectionClosedError()
    corrupted = bytearray(1 if byte == 0 else 0 for byte in chunk)
    await loop.sock_sendall(socket_to, corrupted)
136
137
138class _InterceptBpsLimit:
139 def __init__(self, bytes_per_second: float):
140 assert bytes_per_second >= 1
141 self._bytes_per_second = bytes_per_second
142 self._time_last_added = 0.0
143 self._bytes_left = self._bytes_per_second
144
145 def _update_limit(self) -> None:
146 current_time = time.monotonic()
147 elapsed = current_time - self._time_last_added
148 bytes_addition = self._bytes_per_second * elapsed
149 if bytes_addition > 0:
150 self._bytes_left += bytes_addition
151 self._time_last_added = current_time
152
153 if self._bytes_left > self._bytes_per_second:
154 self._bytes_left = self._bytes_per_second
155
156 async def __call__(
157 self,
158 loop: EvLoop,
159 socket_from: Socket,
160 socket_to: Socket,
161 ) -> None:
162 self._update_limit()
163
164 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
165 if bytes_to_recv > 0:
166 data = await loop.sock_recv(socket_from, bytes_to_recv)
167 if not data:
168 raise ConnectionClosedError()
169 self._bytes_left -= len(data)
170
171 await loop.sock_sendall(socket_to, data)
172 else:
173 logger.info('Socket hits the bytes per second limit')
174 await asyncio.sleep(1.0 / self._bytes_per_second)
175
176
177class _InterceptTimeLimit:
178 def __init__(self, timeout: float, jitter: float):
179 self._sockets: dict[Socket, float] = {}
180 assert timeout >= 0.0
181 self._timeout = timeout
182 assert jitter >= 0.0
183 self._jitter = jitter
184
185 def raise_if_timed_out(self, socket_from: Socket) -> None:
186 if socket_from not in self._sockets:
187 jitter = self._jitter * random.random()
188 expire_at = time.monotonic() + self._timeout + jitter
189 self._sockets[socket_from] = expire_at
190
191 if self._sockets[socket_from] <= time.monotonic():
192 del self._sockets[socket_from]
193 raise GateInterceptException('Socket hits the time limit')
194
195 async def __call__(
196 self,
197 loop: EvLoop,
198 socket_from: Socket,
199 socket_to: Socket,
200 ) -> None:
201 self.raise_if_timed_out(socket_from)
202 await _intercept_ok(loop, socket_from, socket_to)
203
204
205class _InterceptSmallerParts:
206 def __init__(self, max_size: int, sleep_per_packet: float):
207 assert max_size > 0
208 self._max_size = max_size
209 self._sleep_per_packet = sleep_per_packet
210
211 async def __call__(
212 self,
213 loop: EvLoop,
214 socket_from: Socket,
215 socket_to: Socket,
216 ) -> None:
217 data = await loop.sock_recv(socket_from, self._max_size)
218 if not data:
219 raise ConnectionClosedError()
220 await asyncio.sleep(self._sleep_per_packet)
221 await loop.sock_sendall(socket_to, data)
222
223
224class _InterceptConcatPackets:
225 def __init__(self, packet_size: int):
226 assert packet_size >= 0
227 self._packet_size = packet_size
228 self._expire_at: float | None = None
229 self._buf = io.BytesIO()
230
231 async def __call__(
232 self,
233 loop: EvLoop,
234 socket_from: Socket,
235 socket_to: Socket,
236 ) -> None:
237 if self._expire_at is None:
238 self._expire_at = time.monotonic() + MAX_DELAY
239
240 if self._expire_at <= time.monotonic():
241 pytest.fail(
242 f'Failed to make a packet of sufficient size in {MAX_DELAY} '
243 'seconds. Check the test logic, it should end with checking '
244 'that the data was sent and by calling TcpGate function '
245 'to_client_pass() to pass the remaining packets.',
246 )
247
248 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
249 if not data:
250 raise ConnectionClosedError()
251 self._buf.write(data)
252 if self._buf.tell() >= self._packet_size:
253 await loop.sock_sendall(socket_to, self._buf.getvalue())
254 self._buf = io.BytesIO()
255 self._expire_at = None
256
257
258class _InterceptBytesLimit:
259 def __init__(self, bytes_limit: int, gate: BaseGate):
260 assert bytes_limit >= 0
261 self._bytes_limit = bytes_limit
262 self._bytes_remain = self._bytes_limit
263 self._gate = gate
264
265 async def __call__(
266 self,
267 loop: EvLoop,
268 socket_from: Socket,
269 socket_to: Socket,
270 ) -> None:
271 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
272 if not data:
273 raise ConnectionClosedError()
274 if self._bytes_remain <= len(data):
275 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
276 await self._gate.sockets_close()
277 self._bytes_remain = self._bytes_limit
278 raise GateInterceptException('Data transmission limit reached')
279 self._bytes_remain -= len(data)
280 await loop.sock_sendall(socket_to, data)
281
282
283class _InterceptSubstitute:
284 def __init__(self, pattern: str, repl: str, encoding='utf-8'):
285 self._pattern = re.compile(pattern)
286 self._repl = repl
287 self._encoding = encoding
288
289 async def __call__(
290 self,
291 loop: EvLoop,
292 socket_from: Socket,
293 socket_to: Socket,
294 ) -> None:
295 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
296 if not data:
297 raise ConnectionClosedError()
298 try:
299 res = self._pattern.sub(self._repl, data.decode(self._encoding))
300 data = res.encode(self._encoding)
301 except UnicodeError:
302 pass
303 await loop.sock_sendall(socket_to, data)
304
305
306async def _cancel_and_join(task: asyncio.Task | None) -> None:
307 if not task or task.cancelled():
308 return
309
310 try:
311 task.cancel()
312 await task
313 except asyncio.CancelledError:
314 return
315 except Exception: # pylint: disable=broad-except
316 logger.exception('Exception in _cancel_and_join')
317
318
319def _make_socket_nonblocking(sock: Socket) -> None:
320 sock.setblocking(False)
321 if sock.type == socket.SOCK_STREAM:
322 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
323
324
325class _UdpDemuxSocketMock:
326 """
327 Emulates a point-to-point connection over UDP socket
328 with a non-blocking socket interface
329 """
330
331 def gettimeout(self):
332 return self._sock.gettimeout()
333
334 def __init__(self, sock: Socket, peer_address: Address):
335 self._sock: Socket = sock
336 self._peeraddr: Address = peer_address
337
338 sockpair = socket.socketpair(type=socket.SOCK_DGRAM)
339 self._demux_in: Socket = sockpair[0]
340 self._demux_out: Socket = sockpair[1]
341 _make_socket_nonblocking(self._demux_in)
342 _make_socket_nonblocking(self._demux_out)
343 self._is_active: bool = True
344
345 @property
346 def peer_address(self):
347 return self._peeraddr
348
349 async def push(self, loop: EvLoop, data: bytes):
350 return await loop.sock_sendall(self._demux_in, data)
351
352 def is_active(self):
353 return self._is_active
354
355 def close(self):
356 self._is_active = False
357 self._demux_out.close()
358 self._demux_in.close()
359
360 def recvfrom(self, bufsize: int, flags: int = 0):
361 return self._demux_out.recvfrom(bufsize, flags)
362
363 def recv(self, bufsize: int, flags: int = 0):
364 return self._demux_out.recv(bufsize, flags)
365
366 def get_demux_out(self):
367 return self._demux_out
368
369 def fileno(self):
370 return self._demux_out.fileno()
371
372 def send(self, data: bytes):
373 return self._sock.sendto(data, self._peeraddr)
374
375
class InterceptTask:
    """
    Moves data in one direction, from `socket_from` to `socket_to`, by
    repeatedly invoking the currently attached interceptor.

    While the attached interceptor is falsy (e.g. None), run() waits until a
    new one is set via set_interceptor().
    """

    def __init__(self, socket_from, socket_to, interceptor):
        self._socket_from = socket_from
        self._socket_to = socket_to
        self._interceptor = interceptor
        # Guards _interceptor and wakes run() when it is replaced.
        self._condition = asyncio.Condition()

    def get_interceptor(self):
        return self._interceptor

    async def set_interceptor(self, interceptor):
        """Attach `interceptor` and wake up a waiting run() loop."""
        async with self._condition:
            self._interceptor = interceptor
            self._condition.notify()

    async def _next_interceptor(self):
        """Wait until a (truthy) interceptor is attached and return it."""
        async with self._condition:
            return await self._condition.wait_for(self.get_interceptor)

    async def run(self):
        loop = asyncio.get_running_loop()
        while True:
            # Wait for data first and only then pick the interceptor, so that
            # an interceptor replaced while the socket was idle takes effect
            # immediately instead of the outdated one blocking in sock_recv.
            await _wait_for_data(self._socket_from)

            current = await self._next_interceptor()
            logging.trace('running interceptor: %s', current)
            await current(loop, self._socket_from, self._socket_to)
407
408
409class _SocketsPaired:
410 def __init__(
411 self,
412 proxy_name: str,
413 loop: EvLoop,
414 client: socket.socket | _UdpDemuxSocketMock,
415 server: socket.socket,
416 to_server_intercept: Interceptor,
417 to_client_intercept: Interceptor,
418 ) -> None:
419 self._proxy_name = proxy_name
420
421 self._client = client
422 self._server = server
423
424 self._task_to_server = InterceptTask(client, server, to_server_intercept)
425 self._task_to_client = InterceptTask(server, client, to_client_intercept)
426
427 self._task = asyncio.create_task(self._run())
428 self._interceptor_tasks = []
429
430 async def set_to_server_interceptor(self, interceptor):
431 await self._task_to_server.set_interceptor(interceptor)
432
433 async def set_to_client_interceptor(self, interceptor: Interceptor):
434 await self._task_to_client.set_interceptor(interceptor)
435
436 async def shutdown(self) -> None:
437 for task in self._interceptor_tasks:
438 await _cancel_and_join(task)
439 await _cancel_and_join(self._task)
440
441 def is_active(self) -> bool:
442 return not self._task.done()
443
444 def info(self) -> str:
445 if not self.is_active():
446 return '<inactive>'
447
448 return f'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'
449
450 async def _run(self):
451 self._interceptor_tasks = [
452 asyncio.create_task(obj.run()) for obj in (self._task_to_server, self._task_to_client)
453 ]
454 try:
455 done, _ = await asyncio.wait(self._interceptor_tasks, return_when=asyncio.FIRST_EXCEPTION)
456 for task in done:
457 task.result()
458 except GateInterceptException as exc:
459 logger.info('In "%s": %s', self._proxy_name, exc)
460 except OSError as exc:
461 logger.error('Exception in "%s": %s', self._proxy_name, exc)
462 except Exception:
463 logger.exception('interceptor failed')
464 finally:
465 for task in self._interceptor_tasks:
466 task.cancel()
467
468 # Closing the sockets here so that the self.shutdown()
469 # returns only when the sockets are actually closed
470 for sock in self._server, self._client:
471 try:
472 sock.close()
473 except OSError:
474 logger.exception('Exception in "%s" on closing %s:', self._proxy_name, sock)
475
476
477# @endcond
478
479
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = 'Do not use BaseGate itself, use one of specializations TcpGate or UdpGate'

    def __init__(self, route: GateRoute, loop: EvLoop | None = None) -> None:
        self._route = route
        if loop is None:
            loop = asyncio.get_running_loop()
        self._loop = loop

        # Interceptors for connections accepted from now on; the setters below
        # also propagate them to the already established connections.
        # None means "do not read data in that direction".
        self._to_server_intercept: Interceptor | None = _intercept_ok
        self._to_client_intercept: Interceptor | None = _intercept_ok

        self._accept_sockets: list[socket.socket] = []
        self._accept_tasks: list[asyncio.Task[None]] = []

        self._sockets: set[_SocketsPaired] = set()

    async def __aenter__(self) -> BaseGate:
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    def _create_accepting_sockets(self) -> list[Socket]:
        """Create and bind the sockets that receive client traffic."""
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self):
        """
        Open the socket and start accepting tasks.

        Does nothing if the accepting sockets are already open.

        @throws GateException if host_for_client resolves to no address
        """
        if self._accept_sockets:
            return

        self._accept_sockets.extend(self._create_accepting_sockets())

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # In case of stop()+start() bind to the same port
            host, port = self._accept_sockets[0].getsockname()[:2]
            self._route = dataclasses.replace(
                self._route,
                host_for_client=host,
                port_for_client=port,
            )

        self.start_accepting()

    def start_accepting(self) -> None:
        """Start accepting tasks; does nothing while previous ones still run"""
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Stop accepting tasks, close all the sockets
        """
        if not self._accept_sockets and not self._sockets:
            return

        await self.to_server_pass()
        await self.to_client_pass()

        await self.stop_accepting()
        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        for sock in self._accept_sockets:
            sock.close()
        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
        self,
        *,
        count: int | None = None,
    ) -> None:
        """
        Close the connections going through the gate

        @param count close at most this many connections; None closes all
        """
        for x in list(self._sockets)[0:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        assert self._route.port_for_client != 0, 'Gate was not started and the port_for_client is still 0'
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """Return a human-readable description of the open connections"""
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(x.info() for x in self._sockets)

    def _collect_garbage(self) -> None:
        """Forget connection pairs that are no longer active"""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    async def set_to_server_interceptor(self, interceptor: Interceptor | None) -> callinfo.AsyncCallQueue | None:
        """
        Replace existing interceptor of client to server data with a custom

        @param interceptor coroutine function (loop, socket_from, socket_to);
               None stops reading data from clients
        @returns the call queue wrapping the interceptor, or None
        """
        self._to_server_intercept = _create_callqueue(interceptor)
        # Iterate over a snapshot: awaiting below may let the accepting task
        # add new connections to the set.
        for x in list(self._sockets):
            await x.set_to_server_interceptor(self._to_server_intercept)
        return self._to_server_intercept

    async def set_to_client_interceptor(self, interceptor: Interceptor | None) -> callinfo.AsyncCallQueue | None:
        """
        Replace existing interceptor of server to client data with a custom

        @param interceptor coroutine function (loop, socket_from, socket_to);
               None stops reading data from servers
        @returns the call queue wrapping the interceptor, or None
        """
        self._to_client_intercept = _create_callqueue(interceptor)
        # Iterate over a snapshot: awaiting below may let the accepting task
        # add new connections to the set.
        for x in list(self._sockets):
            await x.set_to_client_interceptor(self._to_client_intercept)
        return self._to_client_intercept

    async def to_server_pass(self) -> callinfo.AsyncCallQueue:
        """Pass data as is"""
        logging.trace('to_server_pass')
        return await self.set_to_server_interceptor(_intercept_ok)

    async def to_client_pass(self) -> callinfo.AsyncCallQueue:
        """Pass data as is"""
        logging.trace('to_client_pass')
        return await self.set_to_client_interceptor(_intercept_ok)

    async def to_server_noop(self) -> callinfo.AsyncCallQueue:
        """Do not read data from clients; it stays in the socket buffers"""
        logging.trace('to_server_noop')
        return await self.set_to_server_interceptor(None)

    async def to_client_noop(self) -> callinfo.AsyncCallQueue:
        """Do not read data from the server; it stays in the socket buffers"""
        logging.trace('to_client_noop')
        return await self.set_to_client_interceptor(None)

    async def to_server_drop(self) -> callinfo.AsyncCallQueue:
        """Read and discard data"""
        logging.trace('to_server_drop')
        return await self.set_to_server_interceptor(_intercept_drop)

    async def to_client_drop(self) -> callinfo.AsyncCallQueue:
        """Read and discard data"""
        logging.trace('to_client_drop')
        return await self.set_to_client_interceptor(_intercept_drop)

    async def to_server_delay(self, delay: float) -> callinfo.AsyncCallQueue:
        """Delay data transmission"""
        logging.trace('to_server_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        return await self.set_to_server_interceptor(_intercept_delay_bound)

    async def to_client_delay(self, delay: float) -> callinfo.AsyncCallQueue:
        """Delay data transmission"""
        logging.trace('to_client_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        return await self.set_to_client_interceptor(_intercept_delay_bound)

    async def to_server_close_on_data(self) -> callinfo.AsyncCallQueue:
        """Close on first bytes of data from client"""
        logging.trace('to_server_close_on_data')
        return await self.set_to_server_interceptor(_intercept_close_on_data)

    async def to_client_close_on_data(self) -> callinfo.AsyncCallQueue:
        """Close on first bytes of data from server"""
        logging.trace('to_client_close_on_data')
        return await self.set_to_client_interceptor(_intercept_close_on_data)

    async def to_server_corrupt_data(self) -> callinfo.AsyncCallQueue:
        """Corrupt data received from client"""
        logging.trace('to_server_corrupt_data')
        return await self.set_to_server_interceptor(_intercept_corrupt)

    async def to_client_corrupt_data(self) -> callinfo.AsyncCallQueue:
        """Corrupt data received from server"""
        logging.trace('to_client_corrupt_data')
        return await self.set_to_client_interceptor(_intercept_corrupt)

    async def to_server_limit_bps(self, bytes_per_second: float) -> callinfo.AsyncCallQueue:
        """Limit bytes per second transmission by network from client"""
        logging.trace(
            'to_server_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        return await self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    async def to_client_limit_bps(self, bytes_per_second: float) -> callinfo.AsyncCallQueue:
        """Limit bytes per second transmission by network from server"""
        logging.trace(
            'to_client_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        return await self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    async def to_server_limit_time(self, timeout: float, jitter: float) -> callinfo.AsyncCallQueue:
        """Limit connection lifetime on receive of first bytes from client"""
        logging.trace(
            'to_server_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        return await self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    async def to_client_limit_time(self, timeout: float, jitter: float) -> callinfo.AsyncCallQueue:
        """Limit connection lifetime on receive of first bytes from server"""
        logging.trace(
            'to_client_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        return await self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    async def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> callinfo.AsyncCallQueue:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logging.trace('to_server_smaller_parts, max_size: %s', max_size)
        return await self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    async def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> callinfo.AsyncCallQueue:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logging.trace('to_client_smaller_parts, max_size: %s', max_size)
        return await self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    async def to_server_concat_packets(self, packet_size: int) -> callinfo.AsyncCallQueue:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logging.trace('to_server_concat_packets, packet_size: %s', packet_size)
        return await self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    async def to_client_concat_packets(self, packet_size: int) -> callinfo.AsyncCallQueue:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logging.trace('to_client_concat_packets, packet_size: %s', packet_size)
        return await self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    async def to_server_limit_bytes(self, bytes_limit: int) -> callinfo.AsyncCallQueue:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logging.trace('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        return await self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    async def to_client_limit_bytes(self, bytes_limit: int) -> callinfo.AsyncCallQueue:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logging.trace('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        return await self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    async def to_server_substitute(self, pattern: str, repl: str) -> callinfo.AsyncCallQueue:
        """Apply regex substitution to data from client"""
        logging.trace(
            'to_server_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        return await self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    async def to_client_substitute(self, pattern: str, repl: str) -> callinfo.AsyncCallQueue:
        """Apply regex substitution to data from server"""
        logging.trace(
            'to_client_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        return await self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
837
838
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop | None = None) -> None:
        # Set each time a new connection pair is registered; consumed by
        # wait_for_connections().
        self._connected_event = asyncio.Event()
        super().__init__(route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
        right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        @param count minimal number of connections to wait for
        @param timeout seconds to wait; a non-positive value waits without limit
        @throws asyncio.TimeoutError exception if failed to get the
        required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(),
                timeout=time_left,
            )
            self._connected_event.clear()

    def _create_accepting_sockets(self) -> list[Socket]:
        """Create listening sockets for every address host_for_client resolves to."""
        res: list[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                # Track the socket right away so it is closed on any failure below.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    'Accepting connections on %s, fd=%s',
                    sock.getsockname(),
                    sock.fileno(),
                )
        except BaseException:
            # The caller never receives these sockets, so do not leak them.
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> Socket | None:
        """Connect to the first reachable address of the server; None if none is reachable."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
                logging.trace('Connected to %s', addr[4])
                return server
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logging.warning('Could not connect to %s: %s', addr[4], exc)
            except BaseException:
                # e.g. asyncio.CancelledError from stop_accepting()
                server.close()
                raise
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        while True:
            client, _ = await self._loop.sock_accept(accept_sock)
            try:
                _make_socket_nonblocking(client)
                server = await self._connect_to_server()
            except BaseException:
                # Cancellation (stop_accepting()) or an unexpected error while
                # connecting must not leak the accepted client socket.
                client.close()
                raise

            if server is None:
                client.close()
            else:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()

            self._collect_garbage()
944
945
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as demuxing incoming datagrams
    from different clients.
    Separate connections to server are made for each new client.

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop | None = None):
        # One demultiplexed pseudo-connection per client address.
        self._clients: set[_UdpDemuxSocketMock] = set()
        super().__init__(route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _create_accepting_sockets(self) -> list[Socket]:
        """Create bound UDP sockets for every address host_for_client resolves to."""
        res: list[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                # Track the socket right away so it is closed on any failure below.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except BaseException:
            # The caller never receives these sockets, so do not leak them.
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> Socket | None:
        """Connect to the first usable address of the server; None if none works."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
                logging.trace('Connected to %s', addr[4])
                return server
            except Exception as exc:  # pylint: disable=broad-except
                # Close the failed candidate instead of leaking it.
                server.close()
                logging.warning('Could not connect to %s: %s', addr[4], exc)
            except BaseException:
                # e.g. asyncio.CancelledError from stop_accepting()
                server.close()
                raise
        return None

    def _collect_garbage(self) -> None:
        super()._collect_garbage()
        self._clients = {c for c in self._clients if c.is_active()}

    async def _do_accept(self, accept_sock: Socket):
        sock = asyncio_socket.from_socket(accept_sock)
        while True:
            data, addr = await sock.recvfrom(RECV_MAX_SIZE, timeout=60.0)

            client: _UdpDemuxSocketMock | None = next(
                (known for known in self._clients if known.peer_address == addr),
                None,
            )

            if client is None:
                server = await self._connect_to_server()
                if not server:
                    # NOTE(review): this closes the shared accepting socket,
                    # so the gate stops receiving datagrams from all clients.
                    accept_sock.close()
                    break

                client = _UdpDemuxSocketMock(accept_sock, addr)
                self._clients.add(client)

                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )

            await client.push(self._loop, data)
            self._collect_garbage()

    async def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    async def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    async def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    async def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')
1058
1059
1060def _create_callqueue(obj):
1061 if obj is None:
1062 return None
1063
1064 # workaround testsuite acallqueue that does not work with instances
1065 if isinstance(obj, callinfo.AsyncCallQueue):
1066 return obj
1067 if hasattr(obj, '__name__'):
1068 return callinfo.acallqueue(obj)
1069
1070 @functools.wraps(obj)
1071 async def wrapper(*args, **kwargs):
1072 return await obj(*args, **kwargs)
1073
1074 return callinfo.acallqueue(wrapper)
1075
1076
async def _wait_for_data(sock, timeout=60.0):
    """
    Wait until `sock` has data to read, using testsuite's asyncio_socket.

    For a _UdpDemuxSocketMock, the wait is done on its internal read socket.
    """
    readable = sock.get_demux_out() if isinstance(sock, _UdpDemuxSocketMock) else sock
    await asyncio_socket.from_socket(readable).wait_for_data(timeout=timeout)