userver: /data/code/userver/testsuite/pytest_plugins/pytest_userver/chaos.py Source File
Loading...
Searching...
No Matches
chaos.py
1# pylint: disable=too-many-lines
2"""
3Python module that provides testsuite support for
4chaos tests; see
5@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
6
7@ingroup userver_testsuite
8"""
9
10import asyncio
11import dataclasses
12import fcntl
13import logging
14import os
15import random
16import re
17import socket
18import sys
19import time
20import typing
21
22
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Class that describes the route for TcpGate or UdpGate.

    Use `port_for_client == 0` to bind to some unused port. In that case the
    actual address could be retrieved via BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Name of the route, used to tell gates apart in log messages.
    name: str
    # Address the gate forwards traffic to; the server must listen on it.
    host_to_server: str
    port_to_server: int
    # Address the gate receives client traffic on; port 0 picks a free port.
    host_for_client: str = '127.0.0.1'
    port_for_client: int = 0
39
40
# @cond

logger = logging.getLogger(__name__)

# Largest chunk read by a single recv()/recvfrom() call, see
# https://docs.python.org/3/library/socket.html#socket.socket.recv
RECV_MAX_SIZE = 4096
# Seconds that _InterceptConcatPackets waits for a big enough packet.
MAX_DELAY = 60.0

# Type aliases shared by the whole module.
Address = typing.Tuple[str, int]  # (host, port)
EvLoop = typing.Any  # the asyncio event loop object
Socket = socket.socket
# Async callable (loop, socket_from, socket_to) that moves one portion of
# data from socket_from to socket_to, possibly altering or dropping it.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
58
59
class GateException(Exception):
    """Raised when a gate cannot be set up (e.g. unresolvable client host)."""
62
63
class GateInterceptException(Exception):
    """Raised by an interceptor to terminate the data channel it serves."""
66
67
async def _yield() -> None:
    """
    Let other coroutines run.

    asyncio.sleep(0) is the cheapest way to switch coroutines, see
    https://docs.python.org/3/library/asyncio-task.html#sleeping
    """
    await asyncio.sleep(0)
74
75
def _try_get_message(
    recv_socket: Socket,
    flags: int,
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
    """
    Make a single recvfrom() attempt of up to RECV_MAX_SIZE bytes.

    @returns (data, address) as returned by recvfrom(), or (None, None) when
             the call reports BlockingIOError or InterruptedError.
    """
    try:
        received = recv_socket.recvfrom(RECV_MAX_SIZE, flags)
    except (BlockingIOError, InterruptedError):
        # Nothing to read right now, or the call was interrupted.
        received = (None, None)
    return received
84
85
async def _get_message_task(
    recv_socket: Socket,
) -> typing.Tuple[bytes, Address]:
    """
    Poll `recv_socket`, yielding to other coroutines between attempts, until
    a non-empty message arrives.

    @returns (data, sender address)
    """
    msg, addr = _try_get_message(recv_socket, 0)
    while not msg:
        await _yield()
        msg, addr = _try_get_message(recv_socket, 0)
    assert addr
    return msg, addr
96
97
def _incoming_data_size(recv_socket: Socket) -> int:
    """
    Peek at the pending data without consuming it and return its size
    (capped at RECV_MAX_SIZE); 0 when nothing could be read.
    """
    peeked, _ = _try_get_message(recv_socket, socket.MSG_PEEK)
    if not peeked:
        return 0
    return len(peeked)
101
102
async def _intercept_ok(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Forward one chunk (up to RECV_MAX_SIZE bytes) unchanged."""
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(
        socket_to,
        chunk,
    )
110
111
async def _intercept_noop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Neither read nor send anything: data stays pending in `socket_from`."""
    # The parameters are required by the Interceptor signature only.
    del loop, socket_from, socket_to
118
119
async def _intercept_drop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Read one chunk (up to RECV_MAX_SIZE bytes) and throw it away."""
    dropped = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    del dropped  # intentionally never forwarded to socket_to
126
127
async def _intercept_delay(
    delay: float,
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk, hold it for `delay` seconds, then forward it unchanged.

    Not an Interceptor by itself: `delay` has to be bound first, as done by
    BaseGate.to_server_delay() / to_client_delay().
    """
    payload = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, payload)
137
138
async def _intercept_close_on_data(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Consume a single byte and abort the channel.

    @throws GateInterceptException once a byte was received.
    """
    await loop.sock_recv(socket_from, 1)
    reason = 'Closing socket on data'
    raise GateInterceptException(reason)
146
147
async def _intercept_corrupt(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Read one chunk and forward a corrupted copy of it: every zero byte
    becomes 1 and every non-zero byte becomes 0.
    """
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    corrupted = bytearray(0 if octet else 1 for octet in chunk)
    await loop.sock_sendall(socket_to, corrupted)
155
156
class _InterceptBpsLimit:
    """
    Limits the forwarded traffic to `bytes_per_second`.

    A byte budget refills proportionally to the elapsed time and never
    exceeds one second worth of data (token-bucket style).
    """

    def __init__(self, bytes_per_second: float):
        assert bytes_per_second >= 1
        self._bytes_per_second = bytes_per_second
        # Refill time is measured from the monotonic clock's reference point
        # for the very first call.
        self._time_last_added = 0.0
        self._bytes_left = self._bytes_per_second

    def _update_limit(self) -> None:
        """Refill the budget for the time passed since the last refill."""
        now = time.monotonic()
        refill = self._bytes_per_second * (now - self._time_last_added)
        if refill > 0:
            self._bytes_left += refill
            self._time_last_added = now

        # Never accumulate more than one second worth of bytes.
        self._bytes_left = min(self._bytes_left, self._bytes_per_second)

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        self._update_limit()

        budget = min(int(self._bytes_left), RECV_MAX_SIZE)
        if budget <= 0:
            logger.info('Socket hits the bytes per second limit')
            # Wait about the time needed to refill a single byte.
            await asyncio.sleep(1.0 / self._bytes_per_second)
            return

        chunk = await loop.sock_recv(socket_from, budget)
        self._bytes_left -= len(chunk)
        await loop.sock_sendall(socket_to, chunk)
192
193
class _InterceptTimeLimit:
    """
    Passes data unchanged until a per-socket deadline expires, then aborts
    the channel.

    The deadline is fixed on the first call for a socket as
    now + timeout + jitter * random.random().
    """

    def __init__(self, timeout: float, jitter: float):
        self._sockets: typing.Dict[Socket, float] = {}
        assert timeout >= 0.0
        self._timeout = timeout
        assert jitter >= 0.0
        self._jitter = jitter

    def raise_if_timed_out(self, socket_from: Socket) -> None:
        """
        Register `socket_from` on first sight, then check its deadline.

        @throws GateInterceptException if the deadline has passed; the socket
                is forgotten, so a later call would start a new deadline.
        """
        expire_at = self._sockets.get(socket_from)
        if expire_at is None:
            extra = self._jitter * random.random()
            expire_at = time.monotonic() + self._timeout + extra
            self._sockets[socket_from] = expire_at

        if expire_at > time.monotonic():
            return

        del self._sockets[socket_from]
        raise GateInterceptException('Socket hits the time limit')

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        self.raise_if_timed_out(socket_from)
        await _intercept_ok(loop, socket_from, socket_to)
220
221
class _InterceptSmallerParts:
    """
    Forwards pending data in chunks of at most `max_size` bytes, sleeping
    `sleep_per_packet` seconds before sending each chunk.
    """

    def __init__(self, max_size: int, sleep_per_packet: float):
        assert max_size > 0
        self._max_size = max_size
        self._sleep_per_packet = sleep_per_packet

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        part_size = min(_incoming_data_size(socket_from), self._max_size)
        part = await loop.sock_recv(socket_from, part_size)
        await asyncio.sleep(self._sleep_per_packet)
        await loop.sock_sendall(socket_to, part)
239
240
class _InterceptConcatPackets:
    """
    Holds data back until at least `packet_size` bytes are pending, then
    forwards them in one piece.

    If that does not happen within MAX_DELAY seconds, an error is logged and
    sys.exit(2) is called.

    NOTE(review): pending data is measured by peeking at most RECV_MAX_SIZE
    bytes, so a `packet_size` above that can never be satisfied.
    """

    def __init__(self, packet_size: int):
        assert packet_size >= 0
        self._packet_size = packet_size
        # Deadline for the packet being accumulated; None when none is.
        self._expire_at: typing.Optional[float] = None

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        if self._expire_at is None:
            self._expire_at = time.monotonic() + MAX_DELAY

        if time.monotonic() >= self._expire_at:
            logger.error(
                f'Failed to make a packet of sufficient size in {MAX_DELAY} '
                'seconds. Check the test logic, it should end with checking '
                'that the data was sent and by calling TcpGate function '
                'to_client_pass() to pass the remaining packets.',
            )
            sys.exit(2)

        if _incoming_data_size(socket_from) < self._packet_size:
            return

        packet = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        await loop.sock_sendall(socket_to, packet)
        self._expire_at = None
270
271
class _InterceptBytesLimit:
    """
    Forwards data until `bytes_limit` bytes went through, then closes all
    the connections of `gate`, resets the counter and aborts the channel.

    NOTE(review): gate.sockets_close() also shuts down the socket pair whose
    pipe task is running this interceptor; verify that this self-cancellation
    terminates cleanly.
    """

    def __init__(self, bytes_limit: int, gate: 'BaseGate'):
        assert bytes_limit >= 0
        self._bytes_limit = bytes_limit
        self._bytes_remain = bytes_limit
        self._gate = gate

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if len(data) < self._bytes_remain:
            self._bytes_remain -= len(data)
            await loop.sock_sendall(socket_to, data)
            return

        # The limit is reached within this chunk: deliver only the allowed
        # prefix, then drop every connection of the gate.
        await loop.sock_sendall(socket_to, data[: self._bytes_remain])
        await self._gate.sockets_close()
        self._bytes_remain = self._bytes_limit
        raise GateInterceptException('Data transmission limit reached')
294
295
class _InterceptSubstitute:
    """
    Applies a regex substitution to every received chunk, decoded with
    `encoding`. Chunks that fail to decode or encode are forwarded as is.

    NOTE(review): the substitution works per chunk, so a match split across
    two recv() calls is not replaced.
    """

    def __init__(self, pattern: str, repl: str, encoding='utf-8'):
        self._pattern = re.compile(pattern)
        self._repl = repl
        self._encoding = encoding

    def _substitute(self, data: bytes) -> bytes:
        """Return `data` with the substitution applied, or `data` itself."""
        try:
            text = data.decode(self._encoding)
            return self._pattern.sub(self._repl, text).encode(self._encoding)
        except UnicodeError:
            return data

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        await loop.sock_sendall(socket_to, self._substitute(data))
315
316
async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
    """
    Cancel `task` and wait for it to finish.

    A missing or already cancelled task is ignored. Exceptions raised by the
    task (other than cancellation) are logged instead of being propagated.
    """
    if not task or task.cancelled():
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception as exc:  # pylint: disable=broad-except
        logger.error('Exception in _cancel_and_join: %s', exc)
328
329
def _make_socket_nonblocking(sock: socket.socket) -> None:
    """
    Switch `sock` to non-blocking mode, as required by the asyncio
    loop.sock_*() functions, and disable Nagle's algorithm on TCP sockets.
    """
    sock.setblocking(False)

    # TCP_NODELAY is a TCP option: other stream sockets (e.g. AF_UNIX ones
    # from socket.socketpair()) reject it with OSError.
    is_tcp = sock.type == socket.SOCK_STREAM and sock.family in (
        socket.AF_INET,
        socket.AF_INET6,
    )
    if is_tcp:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    # setblocking(False) already sets O_NONBLOCK. Keep the explicit fcntl()
    # but add the flag to the existing ones instead of overwriting them all.
    flags = fcntl.fcntl(sock, fcntl.F_GETFL)
    fcntl.fcntl(sock, fcntl.F_SETFL, flags | os.O_NONBLOCK)
335
336
class _UdpDemuxSocketMock:
    """
    Socket-like object that presents one UDP peer of a shared socket as a
    point-to-point, non-blocking connection.

    Datagrams of the peer are injected with push() into an internal
    SOCK_DGRAM socketpair and read back through recv()/recvfrom()/fileno();
    send() writes directly to the peer through the shared socket.
    """

    def __init__(self, sock: Socket, peer_address: Address):
        self._sock: Socket = sock
        self._peeraddr: Address = peer_address

        # push() writes into _demux_in, readers consume from _demux_out.
        self._demux_in, self._demux_out = socket.socketpair(
            type=socket.SOCK_DGRAM,
        )
        for end in (self._demux_in, self._demux_out):
            _make_socket_nonblocking(end)
        self._is_active: bool = True

    @property
    def peer_address(self):
        """Address of the UDP peer this object stands for."""
        return self._peeraddr

    def is_active(self):
        """False once close() was called."""
        return self._is_active

    async def push(self, loop: EvLoop, data: bytes):
        """Queue a datagram received from the peer for the readers."""
        return await loop.sock_sendall(self._demux_in, data)

    def recvfrom(self, bufsize: int, flags: int = 0):
        return self._demux_out.recvfrom(bufsize, flags)

    def recv(self, bufsize: int, flags: int = 0):
        return self._demux_out.recv(bufsize, flags)

    def fileno(self):
        # Descriptor of the reading end, the one callers poll for data.
        return self._demux_out.fileno()

    def send(self, data: bytes):
        """Send `data` to the peer through the shared socket."""
        return self._sock.sendto(data, self._peeraddr)

    def close(self):
        """Deactivate and close the internal pair; the shared socket stays open."""
        self._is_active = False
        self._demux_out.close()
        self._demux_in.close()
380
381
class _SocketsPaired:
    """
    Bidirectional pipe between a client socket and a server socket.

    Two tasks are started on construction: one moves data client -> server
    through the "to server" interceptor, the other server -> client through
    the "to client" interceptor. When one task ends it cancels the other one;
    the task that ends last closes both sockets.
    """

    def __init__(
        self,
        proxy_name: str,
        loop: EvLoop,
        client: typing.Union[socket.socket, _UdpDemuxSocketMock],
        server: socket.socket,
        to_server_intercept: Interceptor,
        to_client_intercept: Interceptor,
    ) -> None:
        # proxy_name is used to prefix log messages.
        self._proxy_name = proxy_name
        self._loop = loop

        self._client = client
        self._server = server

        self._to_server_intercept: Interceptor = to_server_intercept
        self._to_client_intercept: Interceptor = to_client_intercept

        self._task_to_server = asyncio.create_task(
            self._do_pipe_channels(to_server=True),
        )
        self._task_to_client = asyncio.create_task(
            self._do_pipe_channels(to_server=False),
        )

        # Safe to initialize after create_task(): the tasks cannot start
        # running before this synchronous constructor returns.
        self._finished_channels = 0

    async def _do_pipe_channels(self, *, to_server: bool) -> None:
        """
        Move data in one direction until an interceptor raises
        GateInterceptException, a socket error occurs, or the task is
        cancelled.

        NOTE(review): a peer that closed its side peeks as 0 pending bytes,
        so the loop keeps polling instead of ending the channel — verify.
        """
        if to_server:
            socket_from = self._client
            socket_to = self._server
        else:
            socket_from = self._server
            socket_to = self._client

        try:
            while True:
                # Applies new interceptors faster.
                #
                # To avoid long awaiting on sock_recv in an outdated
                # interceptor we wait for data before grabbing and applying
                # the interceptor.
                if not _incoming_data_size(socket_from):
                    await _yield()
                    continue

                # Re-read on every iteration so set_to_*_interceptor()
                # takes effect for already established connections.
                if to_server:
                    interceptor = self._to_server_intercept
                else:
                    interceptor = self._to_client_intercept

                await interceptor(self._loop, socket_from, socket_to)
                await _yield()
        except GateInterceptException as exc:
            logger.info('In "%s": %s', self._proxy_name, exc)
        except socket.error as exc:
            logger.error('Exception in "%s": %s', self._proxy_name, exc)
        finally:
            self._finished_channels += 1
            if self._finished_channels == 2:
                # Closing the sockets here so that the self.shutdown()
                # returns only when the sockets are actually closed
                logger.info('"%s" closes %s', self._proxy_name, self.info())
                self._close_socket(self._client)
                self._close_socket(self._server)
            else:
                # First channel to finish: stop the opposite direction too.
                assert self._finished_channels == 1
                if to_server:
                    self._task_to_client.cancel()
                else:
                    self._task_to_server.cancel()

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """Replace the interceptor of client -> server data."""
        self._to_server_intercept = interceptor

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """Replace the interceptor of server -> client data."""
        self._to_client_intercept = interceptor

    def _close_socket(self, self_socket: Socket) -> None:
        """Close one of the paired sockets, logging (not raising) errors."""
        assert self_socket in {self._client, self._server}
        try:
            self_socket.close()
        except socket.error as exc:
            logger.error(
                'Exception in "%s" on closing %s: %s',
                self._proxy_name,
                'client' if self_socket == self._client else 'server',
                exc,
            )

    async def shutdown(self) -> None:
        """Cancel both pipe tasks and wait for them to finish."""
        for task in {self._task_to_client, self._task_to_server}:
            await _cancel_and_join(task)

    def is_active(self) -> bool:
        """True while at least one of the pipe tasks is still running."""
        return not self._task_to_client.done() or not self._task_to_server.done()

    def info(self) -> str:
        """Describe the pair by file descriptors, or '<inactive>'."""
        if not self.is_active():
            return '<inactive>'

        return f'client fd={self._client.fileno()} <=> ' f'server fd={self._server.fileno()}'
485
486
487# @endcond
488
489
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = (
        'Do not use BaseGate itself, use one of '
        'specializations TcpGate or UdpGate'
    )

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        self._route = route
        self._loop = loop

        # Interceptors given to every new connection and pushed to the
        # existing ones by set_to_*_interceptor().
        self._to_server_intercept: Interceptor = _intercept_ok
        self._to_client_intercept: Interceptor = _intercept_ok

        self._accept_sockets: typing.List[socket.socket] = []
        self._accept_tasks: typing.List[asyncio.Task[None]] = []

        self._sockets: typing.Set[_SocketsPaired] = set()

    async def __aenter__(self) -> 'BaseGate':
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    @staticmethod
    def _log_trace(msg: str, *args: typing.Any) -> None:
        # trace() is not part of the standard `logging` module; use it when
        # the test environment has installed one, otherwise log at DEBUG.
        trace = getattr(logging, 'trace', None)
        if trace is None:
            logger.debug(msg, *args)
        else:
            trace(msg, *args)

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Return the bound sockets that receive client traffic."""
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self):
        """
        Open the socket and start accepting tasks.

        Does nothing if the gate is already started.

        @throws GateException if no accepting socket could be created.
        """
        if self._accept_sockets:
            return

        self._accept_sockets.extend(self._create_accepting_sockets())

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # In case of stop()+start() bind to the same port
            sockname = self._accept_sockets[0].getsockname()
            self._route = dataclasses.replace(
                self._route,
                host_for_client=sockname[0],
                port_for_client=sockname[1],
            )

        BaseGate.start_accepting(self)

    def start_accepting(self) -> None:
        """Start accepting tasks, unless some are still running."""
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Reset the interceptors to pass-through, stop accepting tasks and
        close all the sockets.
        """
        if not self._accept_sockets and not self._sockets:
            return

        self.to_server_pass()
        self.to_client_pass()

        await BaseGate.stop_accepting(self)
        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        for sock in self._accept_sockets:
            sock.close()
        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
        self,
        *,
        count: typing.Optional[int] = None,
    ) -> None:
        """
        Close the connections going through the gate.

        @param count close at most that many connections; None closes all.
        """
        for x in list(self._sockets)[0:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        assert self._route.port_for_client != 0, 'Gate was not started and the port_for_client is still 0'
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """Print info on open sockets"""
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(x.info() for x in self._sockets)

    def _collect_garbage(self) -> None:
        """Forget the socket pairs whose pipe tasks have all finished."""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of client to server data with a custom
        one, for the current and the future connections.
        """
        self._to_server_intercept = interceptor
        for x in self._sockets:
            x.set_to_server_interceptor(self._to_server_intercept)

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of server to client data with a custom
        one, for the current and the future connections.
        """
        self._to_client_intercept = interceptor
        for x in self._sockets:
            x.set_to_client_interceptor(self._to_client_intercept)

    def to_server_pass(self) -> None:
        """Pass data as is"""
        self._log_trace('to_server_pass')
        self.set_to_server_interceptor(_intercept_ok)

    def to_client_pass(self) -> None:
        """Pass data as is"""
        self._log_trace('to_client_pass')
        self.set_to_client_interceptor(_intercept_ok)

    def to_server_noop(self) -> None:
        """Do not read data from the client, so it stays pending in the socket"""
        self._log_trace('to_server_noop')
        self.set_to_server_interceptor(_intercept_noop)

    def to_client_noop(self) -> None:
        """Do not read data from the server, so it stays pending in the socket"""
        self._log_trace('to_client_noop')
        self.set_to_client_interceptor(_intercept_noop)

    def to_server_drop(self) -> None:
        """Read and discard data"""
        self._log_trace('to_server_drop')
        self.set_to_server_interceptor(_intercept_drop)

    def to_client_drop(self) -> None:
        """Read and discard data"""
        self._log_trace('to_client_drop')
        self.set_to_client_interceptor(_intercept_drop)

    def to_server_delay(self, delay: float) -> None:
        """Delay data transmission by `delay` seconds"""
        self._log_trace('to_server_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_server_interceptor(_intercept_delay_bound)

    def to_client_delay(self, delay: float) -> None:
        """Delay data transmission by `delay` seconds"""
        self._log_trace('to_client_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_client_interceptor(_intercept_delay_bound)

    def to_server_close_on_data(self) -> None:
        """Close on first bytes of data from client"""
        self._log_trace('to_server_close_on_data')
        self.set_to_server_interceptor(_intercept_close_on_data)

    def to_client_close_on_data(self) -> None:
        """Close on first bytes of data from server"""
        self._log_trace('to_client_close_on_data')
        self.set_to_client_interceptor(_intercept_close_on_data)

    def to_server_corrupt_data(self) -> None:
        """Corrupt data received from client"""
        self._log_trace('to_server_corrupt_data')
        self.set_to_server_interceptor(_intercept_corrupt)

    def to_client_corrupt_data(self) -> None:
        """Corrupt data received from server"""
        self._log_trace('to_client_corrupt_data')
        self.set_to_client_interceptor(_intercept_corrupt)

    def to_server_limit_bps(self, bytes_per_second: float) -> None:
        """Limit bytes per second transmission by network from client"""
        self._log_trace(
            'to_server_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_client_limit_bps(self, bytes_per_second: float) -> None:
        """Limit bytes per second transmission by network from server"""
        self._log_trace(
            'to_client_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_server_limit_time(self, timeout: float, jitter: float) -> None:
        """
        Limit connection lifetime on receive of first bytes from client

        @param timeout seconds the connection may live after its first data
        @param jitter upper bound of a random addition to `timeout`, seconds
        """
        self._log_trace(
            'to_server_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_client_limit_time(self, timeout: float, jitter: float) -> None:
        """
        Limit connection lifetime on receive of first bytes from server

        @param timeout seconds the connection may live after its first data
        @param jitter upper bound of a random addition to `timeout`, seconds
        """
        self._log_trace(
            'to_client_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        self._log_trace('to_server_smaller_parts, max_size: %s', max_size)
        self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        self._log_trace('to_client_smaller_parts, max_size: %s', max_size)
        self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        self._log_trace('to_server_concat_packets, packet_size: %s', packet_size)
        self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    def to_client_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        self._log_trace('to_client_concat_packets, packet_size: %s', packet_size)
        self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    def to_server_limit_bytes(self, bytes_limit: int) -> None:
        """Drop all connections each time `bytes_limit` bytes from the client went through"""
        self._log_trace('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_client_limit_bytes(self, bytes_limit: int) -> None:
        """Drop all connections each time `bytes_limit` bytes from the server went through"""
        self._log_trace('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_server_substitute(self, pattern: str, repl: str) -> None:
        """Apply regex substitution to data from client"""
        self._log_trace(
            'to_server_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    def to_client_substitute(self, pattern: str, repl: str) -> None:
        """Apply regex substitution to data from server"""
        self._log_trace(
            'to_client_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
839
840
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        # Set by _do_accept() whenever a new connection pair is registered.
        self._connected_event = asyncio.Event()
        BaseGate.__init__(self, route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
        right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        @param count minimal amount of connections to wait for
        @param timeout seconds to wait; `timeout <= 0` waits without a limit

        @throws asyncio.TimeoutError exception if failed to get the
        required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(),
                timeout=time_left,
            )
            self._connected_event.clear()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create a listening socket for every address `host_for_client`
        resolves to.

        If any step fails, all the sockets created so far are closed and the
        exception is propagated.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                # Track the socket right away so that it is closed on failure.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    'Accepting connections on %s, fd=%s',
                    sock.getsockname(),
                    sock.fileno(),
                )
        except BaseException:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect to the first reachable address of the server.

        @returns the connected socket, or None if no address was reachable.
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        # trace() is not part of the standard `logging` module; fall back to
        # DEBUG when the test environment did not install it.
        trace = getattr(logging, 'trace', logger.debug)
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            except BaseException:
                # E.g. cancellation: do not leak the half-connected socket.
                server.close()
                raise
            trace('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        Accept clients forever; pair each one with a new server connection,
        or close it if the server is unreachable.
        """
        while accept_sock:
            client, _ = await self._loop.sock_accept(accept_sock)

            try:
                _make_socket_nonblocking(client)
                server = await self._connect_to_server()
            except BaseException:
                # E.g. cancellation by stop_accepting() or a resolver error:
                # do not leak the freshly accepted client socket.
                client.close()
                raise

            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
946
947
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as demuxing incoming datagrams
    from different clients.
    Separate connections to server are made for each new client.

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop):
        # One demultiplexed pseudo-connection per client address.
        self._clients: typing.Set[_UdpDemuxSocketMock] = set()
        BaseGate.__init__(self, route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create a bound datagram socket for every address `host_for_client`
        resolves to.

        If any step fails, all the sockets created so far are closed and the
        exception is propagated.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                # Track the socket right away so that it is closed on failure.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except BaseException:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect a datagram socket to the first usable address of the server.

        @returns the connected socket, or None if no address could be used.
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        # trace() is not part of the standard `logging` module; fall back to
        # DEBUG when the test environment did not install it.
        trace = getattr(logging, 'trace', logger.debug)
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except Exception as exc:  # pylint: disable=broad-except
                # Close the failed socket instead of leaking it.
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            except BaseException:
                server.close()
                raise
            trace('Connected to %s', addr[4])
            return server
        return None

    def _collect_garbage(self) -> None:
        super()._collect_garbage()
        self._clients = {c for c in self._clients if c.is_active()}

    async def _do_accept(self, accept_sock: Socket):
        """
        Demultiplex datagrams of `accept_sock` by sender address.

        Each new sender gets its own _UdpDemuxSocketMock and server socket;
        every datagram is then pushed to its sender's mock. If the server is
        unreachable for a new sender, `accept_sock` is closed and the task
        ends.
        """
        while True:
            data, addr = await _get_message_task(accept_sock)

            client: typing.Optional[_UdpDemuxSocketMock] = next(
                (known for known in self._clients if addr == known.peer_address),
                None,
            )

            if client is None:
                server = await self._connect_to_server()
                if not server:
                    accept_sock.close()
                    break

                client = _UdpDemuxSocketMock(accept_sock, addr)
                self._clients.add(client)

                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )

            await client.push(self._loop, data)
            self._collect_garbage()

    def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')