userver: testsuite/pytest_plugins/pytest_userver/chaos.py
Loading...
Searching...
No Matches
testsuite/pytest_plugins/pytest_userver/chaos.py
1# pylint: disable=too-many-lines
2"""
3Python module that provides testsuite support for
4chaos tests; see
5@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
6
7@ingroup userver_testsuite
8"""
9
10import asyncio
11import dataclasses
12import fcntl
13import logging
14import os
15import random
16import re
17import select
18import socket
19import sys
20import time
21import typing
22
23
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Immutable description of a single route proxied by TcpGate or UdpGate.

    Clients talk to (host_for_client, port_for_client); the gate forwards
    their traffic to the server at (host_to_server, port_to_server).

    Set `port_for_client` to 0 to let the gate bind to some unused port; the
    actual address can then be obtained via
    BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Route name; shows up in the gate's log messages.
    name: str
    # Where the gate connects to reach the real server.
    host_to_server: str
    port_to_server: int
    # Where the gate listens for clients; port 0 means "pick a free port".
    host_for_client: str = '127.0.0.1'
    port_for_client: int = 0
40
41
42# @cond
43
# Largest chunk read from a socket in a single call; see
# https://docs.python.org/3/library/socket.html#socket.socket.recv
RECV_MAX_SIZE = 4 * 1024
# Timeout in seconds used by _InterceptConcatPackets before giving up.
MAX_DELAY = 60.0


logger = logging.getLogger(__name__)


# Type aliases shared by the helpers and gates below.
Address = typing.Tuple[str, int]
EvLoop = typing.Any
Socket = socket.socket
# An interceptor is awaited as interceptor(loop, socket_from, socket_to).
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
59
60
class GateException(Exception):
    """Raised when a gate cannot be set up (e.g. its host cannot be resolved)."""
63
64
class GateInterceptException(Exception):
    """Raised by an interceptor to make the gate close the connection."""
67
68
69async def _yield() -> None:
70 # Minamal delay can be 0. This will be fast path for coroutine switching
71 # https://docs.python.org/3/library/asyncio-task.html#sleeping
72
73 min_delay = 0
74 await asyncio.sleep(min_delay)
75
76
77def _has_data(sock: socket.socket) -> bool:
78 rlist, _, _ = select.select([sock], [], [], 0)
79 return bool(rlist)
80
81
def _try_get_message(
    recv_socket: Socket,
    flags: int,
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
    """
    Attempt one recvfrom() of up to RECV_MAX_SIZE bytes.

    Meant for non-blocking sockets: returns the (data, address) pair, or
    (None, None) when the read would block or was interrupted.
    """
    try:
        message = recv_socket.recvfrom(RECV_MAX_SIZE, flags)
    except (BlockingIOError, InterruptedError):
        message = (None, None)
    return message
90
91
async def _get_message_task(
    recv_socket: Socket,
) -> typing.Tuple[bytes, Address]:
    """
    Poll `recv_socket` until a non-empty datagram arrives.

    Returns the datagram together with its sender address; empty reads are
    skipped and the loop yields control between attempts.
    """
    while True:
        payload, sender = _try_get_message(recv_socket, 0)
        if not payload:
            await _yield()
            continue
        assert sender
        return payload, sender
102
103
def _incoming_data_size(recv_socket: Socket) -> int:
    """Size of the data waiting in `recv_socket` (peeked, capped at RECV_MAX_SIZE)."""
    peeked, _ = _try_get_message(recv_socket, socket.MSG_PEEK)
    return len(peeked or b'')
107
108
async def _intercept_ok(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Forward one chunk of up to RECV_MAX_SIZE bytes unchanged."""
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, chunk)
116
117
async def _intercept_noop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Leave the data unread, so it stays queued in `socket_from`."""
    return None
124
125
async def _intercept_drop(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Consume up to RECV_MAX_SIZE bytes from `socket_from` and discard them."""
    _discarded = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
132
133
async def _intercept_delay(
    delay: float,
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Read one chunk, hold it for `delay` seconds, then forward it."""
    payload = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, payload)
143
144
async def _intercept_close_on_data(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """Wait for a byte from `socket_from`, then request the connection to close."""
    # The received byte is consumed and never forwarded.
    await loop.sock_recv(socket_from, 1)
    raise GateInterceptException('Closing socket on data')
152
153
async def _intercept_corrupt(
    loop: EvLoop,
    socket_from: Socket,
    socket_to: Socket,
) -> None:
    """
    Forward a corrupted copy of one chunk: every zero byte becomes 1 and every
    non-zero byte becomes 0, so each byte differs from the original.
    """
    original = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    corrupted = bytearray(0 if byte else 1 for byte in original)
    await loop.sock_sendall(socket_to, corrupted)
161
162
class _InterceptBpsLimit:
    """
    Token-bucket interceptor that forwards at most `bytes_per_second` bytes
    per second; the bucket never holds more than one second worth of bytes.
    """

    def __init__(self, bytes_per_second: float):
        assert bytes_per_second >= 1
        self._bytes_per_second = bytes_per_second
        self._time_last_added = 0.0
        self._bytes_left = self._bytes_per_second

    def _update_limit(self) -> None:
        """Refill the bucket for the time elapsed since the last refill."""
        now = time.monotonic()
        refill = self._bytes_per_second * (now - self._time_last_added)
        if refill > 0:
            self._bytes_left += refill
            self._time_last_added = now

        # Cap the bucket at one second of traffic.
        if self._bytes_left > self._bytes_per_second:
            self._bytes_left = self._bytes_per_second

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        self._update_limit()

        budget = min(int(self._bytes_left), RECV_MAX_SIZE)
        if budget <= 0:
            logger.info('Socket hits the bytes per second limit')
            # Roughly the time needed for one more byte to become available.
            await asyncio.sleep(1.0 / self._bytes_per_second)
            return

        chunk = await loop.sock_recv(socket_from, budget)
        if not chunk:
            raise RuntimeError('Socket connection was closed by the other side')
        self._bytes_left -= len(chunk)
        await loop.sock_sendall(socket_to, chunk)
200
201
class _InterceptTimeLimit:
    """
    Forwards data unchanged, but closes a connection once `timeout` seconds
    (plus a random share of `jitter`) have passed since the interceptor first
    saw data on that socket.
    """

    def __init__(self, timeout: float, jitter: float):
        # Per-socket deadline, in time.monotonic() units.
        self._sockets: typing.Dict[Socket, float] = {}
        assert timeout >= 0.0
        self._timeout = timeout
        assert jitter >= 0.0
        self._jitter = jitter

    def raise_if_timed_out(self, socket_from: Socket) -> None:
        """Start the clock for a new socket; raise once its deadline passes."""
        deadline = self._sockets.get(socket_from)
        if deadline is None:
            extra = self._jitter * random.random()
            deadline = time.monotonic() + self._timeout + extra
            self._sockets[socket_from] = deadline

        if time.monotonic() >= deadline:
            del self._sockets[socket_from]
            raise GateInterceptException('Socket hits the time limit')

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        self.raise_if_timed_out(socket_from)
        await _intercept_ok(loop, socket_from, socket_to)
228
229
class _InterceptSmallerParts:
    """
    Forwards data in pieces of at most `max_size` bytes, optionally sleeping
    `sleep_per_packet` seconds before sending each piece.
    """

    def __init__(self, max_size: int, sleep_per_packet: float):
        assert max_size > 0
        self._max_size = max_size
        self._sleep_per_packet = sleep_per_packet

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        piece = await loop.sock_recv(socket_from, self._max_size)
        if len(piece) == 0:
            raise RuntimeError('Socket connection was closed by the other side')
        await asyncio.sleep(self._sleep_per_packet)
        await loop.sock_sendall(socket_to, piece)
247
248
class _InterceptConcatPackets:
    """
    Holds data back until at least `packet_size` bytes are queued in the
    source socket, then forwards them in one piece.

    If no big enough packet is assembled within MAX_DELAY seconds the process
    is terminated with exit code 2, because the test logic is likely broken.
    """

    def __init__(self, packet_size: int):
        assert packet_size >= 0
        self._packet_size = packet_size
        # Deadline for assembling the current packet; None until data shows up.
        self._expire_at: typing.Optional[float] = None

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        if self._expire_at is None:
            self._expire_at = time.monotonic() + MAX_DELAY

        if self._expire_at <= time.monotonic():
            logger.error(
                f'Failed to make a packet of sufficient size in {MAX_DELAY} '
                'seconds. Check the test logic, it should end with checking '
                'that the data was sent and by calling TcpGate function '
                'to_client_pass() to pass the remaining packets.',
            )
            sys.exit(2)

        # Peek (and later read) at least `packet_size` bytes: limiting the
        # peek to RECV_MAX_SIZE would make any packet_size above it
        # unreachable and the test would always end in sys.exit().
        bufsize = max(self._packet_size, RECV_MAX_SIZE)
        try:
            incoming_size = len(socket_from.recv(bufsize, socket.MSG_PEEK))
        except (BlockingIOError, InterruptedError):
            incoming_size = 0

        if incoming_size >= self._packet_size:
            data = await loop.sock_recv(socket_from, bufsize)
            await loop.sock_sendall(socket_to, data)
            self._expire_at = None
278
279
class _InterceptBytesLimit:
    """
    Forwards data until `bytes_limit` bytes have passed, then sends only the
    bytes that still fit, closes all connections of `gate`, resets the counter
    and asks for the current connection to be closed.
    """

    def __init__(self, bytes_limit: int, gate: 'BaseGate'):
        assert bytes_limit >= 0
        self._bytes_limit = bytes_limit
        self._bytes_remain = self._bytes_limit
        self._gate = gate

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if not chunk:
            raise RuntimeError('Socket connection was closed by the other side')

        if len(chunk) < self._bytes_remain:
            self._bytes_remain -= len(chunk)
            await loop.sock_sendall(socket_to, chunk)
            return

        # The limit is reached within this chunk: send what still fits.
        await loop.sock_sendall(socket_to, chunk[: self._bytes_remain])
        await self._gate.sockets_close()
        self._bytes_remain = self._bytes_limit
        raise GateInterceptException('Data transmission limit reached')
303
304
class _InterceptSubstitute:
    """
    Applies a regex substitution to each chunk decoded with `encoding`;
    chunks that fail to decode or re-encode are forwarded untouched.
    """

    def __init__(self, pattern: str, repl: str, encoding='utf-8'):
        self._pattern = re.compile(pattern)
        self._repl = repl
        self._encoding = encoding

    async def __call__(
        self,
        loop: EvLoop,
        socket_from: Socket,
        socket_to: Socket,
    ) -> None:
        raw = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        try:
            text = self._pattern.sub(self._repl, raw.decode(self._encoding))
            outgoing = text.encode(self._encoding)
        except UnicodeError:
            outgoing = raw
        await loop.sock_sendall(socket_to, outgoing)
324
325
326async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
327 if not task or task.cancelled():
328 return
329
330 try:
331 task.cancel()
332 await task
333 except asyncio.CancelledError:
334 return
335 except Exception: # pylint: disable=broad-except
336 logger.exception('Exception in _cancel_and_join')
337
338
339def _make_socket_nonblocking(sock: Socket) -> None:
340 sock.setblocking(False)
341 if sock.type == socket.SOCK_STREAM:
342 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
343 fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)
344
345
class _UdpDemuxSocketMock:
    """
    Socket-like stand-in for one UDP peer of a shared UDP socket.

    Datagrams from the peer are pushed into an internal datagram socketpair
    and read back through recv()/recvfrom()/fileno(), which makes the object
    usable like a point-to-point non-blocking socket; send() answers the peer
    through the shared socket.
    """

    def __init__(self, sock: Socket, peer_address: Address):
        self._sock: Socket = sock
        self._peeraddr: Address = peer_address

        # push() writes into `_demux_in`; the read-side methods use `_demux_out`.
        demux_in, demux_out = socket.socketpair(type=socket.SOCK_DGRAM)
        self._demux_in: Socket = demux_in
        self._demux_out: Socket = demux_out
        _make_socket_nonblocking(self._demux_in)
        _make_socket_nonblocking(self._demux_out)
        self._is_active: bool = True

    @property
    def peer_address(self):
        """Address of the UDP peer this object represents."""
        return self._peeraddr

    async def push(self, loop: EvLoop, data: bytes):
        """Queue a datagram received from the peer for reading."""
        return await loop.sock_sendall(self._demux_in, data)

    def is_active(self):
        """False once close() has been called."""
        return self._is_active

    def close(self):
        """Close the internal socketpair; the shared socket is left open."""
        self._is_active = False
        self._demux_out.close()
        self._demux_in.close()

    def recvfrom(self, bufsize: int, flags: int = 0):
        return self._demux_out.recvfrom(bufsize, flags)

    def recv(self, bufsize: int, flags: int = 0):
        return self._demux_out.recv(bufsize, flags)

    def fileno(self):
        return self._demux_out.fileno()

    def send(self, data: bytes):
        """Send `data` to the peer via the shared socket."""
        return self._sock.sendto(data, self._peeraddr)
389
390
class _SocketsPaired:
    """
    One proxied connection: a client-side and a server-side socket plus two
    tasks pumping data client->server and server->client through the
    currently configured interceptors.

    When either pumping task finishes, for whatever reason, both sockets are
    closed and the opposite task is cancelled.
    """

    def __init__(
        self,
        proxy_name: str,
        loop: EvLoop,
        client: typing.Union[socket.socket, _UdpDemuxSocketMock],
        server: socket.socket,
        to_server_intercept: Interceptor,
        to_client_intercept: Interceptor,
    ) -> None:
        self._proxy_name = proxy_name
        self._loop = loop

        self._client = client
        self._server = server

        self._to_server_intercept: Interceptor = to_server_intercept
        self._to_client_intercept: Interceptor = to_client_intercept

        # The client->server pump is created first, then the reverse one.
        self._task_to_server = asyncio.create_task(self._do_pipe_channels(to_server=True))
        self._task_to_client = asyncio.create_task(self._do_pipe_channels(to_server=False))

    async def _do_pipe_channels(self, *, to_server: bool) -> None:
        """Pump one direction until an interceptor or the socket gives up."""
        if to_server:
            socket_from, socket_to = self._client, self._server
        else:
            socket_from, socket_to = self._server, self._client

        try:
            while True:
                # The interceptor is looked up only once data is available,
                # so a newly installed interceptor applies to the next chunk
                # instead of an outdated one blocking in sock_recv.
                if _has_data(socket_from):
                    interceptor = self._to_server_intercept if to_server else self._to_client_intercept
                    await interceptor(self._loop, socket_from, socket_to)
                await _yield()
        except GateInterceptException as exc:
            logger.info('In "%s": %s', self._proxy_name, exc)
        except socket.error as exc:
            logger.error('Exception in "%s": %s', self._proxy_name, exc)
        finally:
            # Closing here makes shutdown() return only after the sockets
            # are really closed; then the opposite pump is stopped.
            self._close_socket(self._client)
            self._close_socket(self._server)
            logger.info('"%s" closes %s', self._proxy_name, self.info())
            sibling = self._task_to_client if to_server else self._task_to_server
            sibling.cancel()

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        self._to_server_intercept = interceptor

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        self._to_client_intercept = interceptor

    def _close_socket(self, self_socket: Socket) -> None:
        """Close one of the paired sockets, logging (not raising) errors."""
        assert self_socket in {self._client, self._server}
        try:
            self_socket.close()
        except socket.error as exc:
            side = 'client' if self_socket == self._client else 'server'
            logger.error(
                'Exception in "%s" on closing %s: %s',
                self._proxy_name,
                side,
                exc,
            )

    async def shutdown(self) -> None:
        """Cancel both pumping tasks and wait for them to finish."""
        pumps = {self._task_to_client, self._task_to_server}
        for pump in pumps:
            await _cancel_and_join(pump)

    def is_active(self) -> bool:
        """True while at least one of the pumping tasks is still running."""
        return not (self._task_to_client.done() and self._task_to_server.done())

    def info(self) -> str:
        """Describe the pair's file descriptors, or '<inactive>'."""
        if self.is_active():
            return f'client fd={self._client.fileno()} <=> server fd={self._server.fileno()}'
        return '<inactive>'
489
490
491# @endcond
492
493
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = 'Do not use BaseGate itself, use one of specializations TcpGate or UdpGate'

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
        self._route = route
        if loop is None:
            loop = asyncio.get_running_loop()
        self._loop = loop

        # Interceptors given to newly accepted connections; existing ones are
        # updated by set_to_server_interceptor()/set_to_client_interceptor().
        self._to_server_intercept: Interceptor = _intercept_ok
        self._to_client_intercept: Interceptor = _intercept_ok

        self._accept_sockets: typing.List[socket.socket] = []
        self._accept_tasks: typing.List[asyncio.Task[None]] = []

        self._sockets: typing.Set[_SocketsPaired] = set()

    async def __aenter__(self) -> 'BaseGate':
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self):
        """Open the socket and start accepting tasks"""
        if self._accept_sockets:
            return

        self._accept_sockets.extend(self._create_accepting_sockets())

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # Remember the actual address, so that stop()+start() binds to
            # the same port again.
            self._route = GateRoute(
                name=self._route.name,
                host_to_server=self._route.host_to_server,
                port_to_server=self._route.port_to_server,
                host_for_client=self._accept_sockets[0].getsockname()[0],
                port_for_client=self._accept_sockets[0].getsockname()[1],
            )

        self.start_accepting()

    def start_accepting(self) -> None:
        """Start accepting tasks (no-op while previous ones are still running)"""
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Stop accepting tasks, close all the sockets
        """
        if not self._accept_sockets and not self._sockets:
            return

        self.to_server_pass()
        self.to_client_pass()

        await self.stop_accepting()
        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        for sock in self._accept_sockets:
            sock.close()
        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
        self,
        *,
        count: typing.Optional[int] = None,
    ) -> None:
        """
        Close the connections going through the gate: all of them, or at most
        `count` of them (in no particular order) if `count` is given.
        """
        for x in list(self._sockets)[0:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        assert self._route.port_for_client != 0, 'Gate was not started and the port_for_client is still 0'
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """Return a human-readable description of the open sockets"""
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(x.info() for x in self._sockets)

    def _collect_garbage(self) -> None:
        """Forget socket pairs whose pumping tasks have both finished."""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace the interceptor of client to server data with a custom one,
        both for existing and for future connections
        """
        self._to_server_intercept = interceptor
        for x in self._sockets:
            x.set_to_server_interceptor(self._to_server_intercept)

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace the interceptor of server to client data with a custom one,
        both for existing and for future connections
        """
        self._to_client_intercept = interceptor
        for x in self._sockets:
            x.set_to_client_interceptor(self._to_client_intercept)

    # Note: the module `logger` is used below; the stdlib `logging` module has
    # no `trace()` function, so `logging.trace(...)` only works if something
    # monkeypatches it.

    def to_server_pass(self) -> None:
        """Pass data as is"""
        logger.debug('to_server_pass')
        self.set_to_server_interceptor(_intercept_ok)

    def to_client_pass(self) -> None:
        """Pass data as is"""
        logger.debug('to_client_pass')
        self.set_to_client_interceptor(_intercept_ok)

    def to_server_noop(self) -> None:
        """Do not read data from the client, leaving it queued in the socket"""
        logger.debug('to_server_noop')
        self.set_to_server_interceptor(_intercept_noop)

    def to_client_noop(self) -> None:
        """Do not read data from the server, leaving it queued in the socket"""
        logger.debug('to_client_noop')
        self.set_to_client_interceptor(_intercept_noop)

    def to_server_drop(self) -> None:
        """Read and discard data"""
        logger.debug('to_server_drop')
        self.set_to_server_interceptor(_intercept_drop)

    def to_client_drop(self) -> None:
        """Read and discard data"""
        logger.debug('to_client_drop')
        self.set_to_client_interceptor(_intercept_drop)

    def to_server_delay(self, delay: float) -> None:
        """Delay data transmission by `delay` seconds"""
        logger.debug('to_server_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_server_interceptor(_intercept_delay_bound)

    def to_client_delay(self, delay: float) -> None:
        """Delay data transmission by `delay` seconds"""
        logger.debug('to_client_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop,
            socket_from: Socket,
            socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_client_interceptor(_intercept_delay_bound)

    def to_server_close_on_data(self) -> None:
        """Close on first bytes of data from client"""
        logger.debug('to_server_close_on_data')
        self.set_to_server_interceptor(_intercept_close_on_data)

    def to_client_close_on_data(self) -> None:
        """Close on first bytes of data from server"""
        logger.debug('to_client_close_on_data')
        self.set_to_client_interceptor(_intercept_close_on_data)

    def to_server_corrupt_data(self) -> None:
        """Corrupt data received from client"""
        logger.debug('to_server_corrupt_data')
        self.set_to_server_interceptor(_intercept_corrupt)

    def to_client_corrupt_data(self) -> None:
        """Corrupt data received from server"""
        logger.debug('to_client_corrupt_data')
        self.set_to_client_interceptor(_intercept_corrupt)

    def to_server_limit_bps(self, bytes_per_second: float) -> None:
        """Limit bytes per second transmission by network from client"""
        logger.debug(
            'to_server_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_client_limit_bps(self, bytes_per_second: float) -> None:
        """Limit bytes per second transmission by network from server"""
        logger.debug(
            'to_client_limit_bps, bytes_per_second: %s',
            bytes_per_second,
        )
        self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_server_limit_time(self, timeout: float, jitter: float) -> None:
        """Limit connection lifetime on receive of first bytes from client"""
        logger.debug(
            'to_server_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_client_limit_time(self, timeout: float, jitter: float) -> None:
        """Limit connection lifetime on receive of first bytes from server"""
        logger.debug(
            'to_client_limit_time, timeout: %s, jitter: %s',
            timeout,
            jitter,
        )
        self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_server_smaller_parts, max_size: %s', max_size)
        self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_client_smaller_parts, max_size: %s', max_size)
        self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_server_concat_packets, packet_size: %s', packet_size)
        self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    def to_client_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_client_concat_packets, packet_size: %s', packet_size)
        self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    def to_server_limit_bytes(self, bytes_limit: int) -> None:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logger.debug('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_client_limit_bytes(self, bytes_limit: int) -> None:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logger.debug('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_server_substitute(self, pattern: str, repl: str) -> None:
        """Apply regex substitution to data from client"""
        logger.debug(
            'to_server_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    def to_client_substitute(self, pattern: str, repl: str) -> None:
        """Apply regex substitution to data from server"""
        logger.debug(
            'to_client_substitute, pattern: %s, repl: %s',
            pattern,
            repl,
        )
        self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
845
846
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None) -> None:
        # Set each time a new connection pair is registered; lets
        # wait_for_connections() sleep until the connection count may change.
        self._connected_event = asyncio.Event()
        super().__init__(route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
        right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        A non-positive `timeout` means waiting without a time limit.

        @throws asyncio.TimeoutError exception if failed to get the
        required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(),
                timeout=time_left,
            )
            self._connected_event.clear()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create listening sockets for every address host_for_client resolves to."""
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                # Registered before configuring so it is closed on failure below.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    f'Accepting connections on {sock.getsockname()}, fd={sock.fileno()}',
                )
        except Exception:
            # Do not leak the listening sockets opened before the failure.
            for opened in res:
                opened.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """Connect to the first reachable server address, or return None."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            _make_socket_nonblocking(server)
            try:
                await self._loop.sock_connect(server, addr[4])
                logger.debug('Connected to %s', addr[4])
                return server
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        """Accept clients forever, pairing each with a new server connection."""
        while True:
            client, _ = await self._loop.sock_accept(accept_sock)
            _make_socket_nonblocking(client)

            server = await self._connect_to_server()
            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
952
953
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as demuxing incoming datagrams
    from different clients.
    Separate connections to server are made for each new client.

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: typing.Optional[EvLoop] = None):
        # Per-peer demultiplexers, matched by the sender address of datagrams.
        self._clients: typing.Set[_UdpDemuxSocketMock] = set()
        super().__init__(route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create bound UDP sockets for every address host_for_client resolves to."""
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                # Registered before configuring so it is closed on failure below.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug(f'Accepting connections on {sock.getsockname()}')
        except Exception:
            # Do not leak the sockets opened before the failure.
            for opened in res:
                opened.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """Connect to the first usable server address, or return None."""
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
                logger.debug('Connected to %s', addr[4])
                return server
            except Exception as exc:  # pylint: disable=broad-except
                # Close the failed socket (as TcpGate does) instead of leaking it.
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
        return None

    def _collect_garbage(self) -> None:
        super()._collect_garbage()
        self._clients = {c for c in self._clients if c.is_active()}

    async def _do_accept(self, accept_sock: Socket):
        """Demultiplex datagrams from `accept_sock` by sender address."""
        while True:
            data, addr = await _get_message_task(accept_sock)

            # Reuse only a demultiplexer that is still open: a pair closed by
            # an interceptor is not garbage-collected yet, and pushing into
            # its closed socketpair would raise and kill this accepting task.
            client: typing.Optional[_UdpDemuxSocketMock] = next(
                (known for known in self._clients if known.is_active() and known.peer_address == addr),
                None,
            )

            if client is None:
                server = await self._connect_to_server()
                if not server:
                    accept_sock.close()
                    break

                client = _UdpDemuxSocketMock(accept_sock, addr)
                self._clients.add(client)

                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )

            await client.push(self._loop, data)
            self._collect_garbage()

    def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_server_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    def to_client_smaller_parts(
        self,
        max_size: int,
        *,
        sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')
1060 max_size: int,
1061 *,
1062 sleep_per_packet: float = 0,
1063 ) -> None:
1064 raise NotImplementedError('Udp packets cannot be split')