1
2"""
3Python module that provides testsuite support for
4chaos tests; see
5@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
6
7@ingroup userver_testsuite
8"""
9
10import asyncio
11import dataclasses
12import fcntl
13import logging
14import os
15import random
16import re
17import socket
18import sys
19import time
20import typing
21
22
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Class that describes the route for TcpGate or UdpGate.

    Use `port_for_client == 0` to bind to some unused port. In that case the
    actual address could be retrieved via BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Route name; shown in the gate's log messages.
    name: str
    # Host the gate connects to when forwarding client traffic.
    host_to_server: str
    # Port the gate connects to when forwarding client traffic.
    port_to_server: int
    # Host the gate binds its accepting socket(s) to.
    host_for_client: str = '127.0.0.1'
    # Port the gate binds to; 0 lets the OS choose an unused port.
    port_for_client: int = 0
39
40
41
42
43
# Maximum number of bytes received (or peeked) from a socket in one call.
RECV_MAX_SIZE = 4096
# Seconds _InterceptConcatPackets waits for a full packet before exiting.
MAX_DELAY = 60.0


logger = logging.getLogger(__name__)


# Socket address as returned by recvfrom()/getsockname().
Address = typing.Tuple[str, int]
# Event loop object; typed loosely, presumably an asyncio event loop.
EvLoop = typing.Any
Socket = socket.socket
# Coroutine function (loop, socket_from, socket_to) that handles pending
# data for one direction of a proxied connection.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket], typing.Coroutine[typing.Any, typing.Any, None],
]
57
58
class GateException(Exception):
    """Raised when a gate cannot be set up, e.g. the client host won't resolve."""
61
62
class GateInterceptException(Exception):
    """Raised by an interceptor to stop piping data over a connection."""
65
66
async def _yield() -> None:
    """
    Hand control back to the event loop so that other coroutines can run.

    Sleeping for zero seconds suspends only until the next loop iteration,
    so callers that poll with this helper never block.
    """
    await asyncio.sleep(0)
73
74
def _try_get_message(
    recv_socket: Socket, flags: int,
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
    """
    Non-blocking recvfrom() of up to RECV_MAX_SIZE bytes.

    @param recv_socket socket (or socket-like object) to read from
    @param flags recvfrom() flags, e.g. socket.MSG_PEEK to leave data queued
    @returns (data, sender) on success, (None, None) if the call would block
             or was interrupted
    """
    try:
        payload, sender = recv_socket.recvfrom(RECV_MAX_SIZE, flags)
    except (BlockingIOError, InterruptedError):
        payload, sender = None, None
    return payload, sender
82
83
async def _get_message_task(
    recv_socket: Socket,
) -> typing.Tuple[bytes, Address]:
    """
    Poll `recv_socket` until a non-empty datagram arrives.

    @returns (data, sender address) of that datagram
    """
    payload, sender = _try_get_message(recv_socket, 0)
    while not payload:
        await _yield()
        payload, sender = _try_get_message(recv_socket, 0)

    assert sender
    return payload, sender
94
95
def _incoming_data_size(recv_socket: Socket) -> int:
    """
    Number of bytes that could be read right now, without consuming them.

    The peek is capped at RECV_MAX_SIZE. Returns 0 both when nothing is
    pending and when the peek yields an empty result (e.g. end of stream).
    """
    peeked, _ = _try_get_message(recv_socket, socket.MSG_PEEK)
    if not peeked:
        return 0
    return len(peeked)
99
100
async def _intercept_ok(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Forward one chunk of up to RECV_MAX_SIZE bytes unchanged."""
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, chunk)
106
107
async def _intercept_noop(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Do nothing: pending data stays unread in `socket_from`."""
    return None
112
113
async def _intercept_drop(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Read one chunk of up to RECV_MAX_SIZE bytes and throw it away."""
    discarded = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    del discarded  # deliberately never forwarded to socket_to
118
119
async def _intercept_delay(
    delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Read one chunk, hold it for `delay` seconds, then forward it unchanged.

    @param delay pause before forwarding, in seconds
    """
    held = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, held)
126
127
async def _intercept_close_on_data(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Consume a single byte and abort the connection.

    @throws GateInterceptException always, after the byte was read
    """
    first_byte = await loop.sock_recv(socket_from, 1)
    del first_byte  # never forwarded
    raise GateInterceptException('Closing socket on data')
133
134
async def _intercept_corrupt(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Forward one chunk with every byte replaced: 0 becomes 1, anything else 0.

    Every output byte therefore differs from the corresponding input byte.
    """
    original = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    corrupted = bytearray(0 if byte else 1 for byte in original)
    await loop.sock_sendall(socket_to, corrupted)
140
141
class _InterceptBpsLimit:
    """
    Throttle a stream to roughly `bytes_per_second`.

    A byte budget is refilled in proportion to the elapsed time and capped at
    one second's worth; each call forwards at most the current budget.
    """

    def __init__(self, bytes_per_second: float):
        assert bytes_per_second >= 1
        self._bytes_per_second = bytes_per_second
        self._time_last_added = 0.0
        # Start with a full one-second budget.
        self._bytes_left = self._bytes_per_second

    def _update_limit(self) -> None:
        """Add budget for the time elapsed since the last refill."""
        now = time.monotonic()
        refill = self._bytes_per_second * (now - self._time_last_added)
        if refill > 0:
            self._bytes_left += refill
            self._time_last_added = now

        # Never accumulate more than one second of budget.
        self._bytes_left = min(self._bytes_left, self._bytes_per_second)

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        self._update_limit()

        budget = min(int(self._bytes_left), RECV_MAX_SIZE)
        if budget <= 0:
            logger.info('Socket hits the bytes per second limit')
            # Wait roughly the time it takes to earn one byte of budget.
            await asyncio.sleep(1.0 / self._bytes_per_second)
            return

        data = await loop.sock_recv(socket_from, budget)
        self._bytes_left -= len(data)
        await loop.sock_sendall(socket_to, data)
174
175
class _InterceptTimeLimit:
    """
    Pass data through until a per-socket deadline, then abort the connection.

    The deadline for a socket is set the first time it is seen:
    now + timeout + a random fraction of `jitter` seconds.
    """

    def __init__(self, timeout: float, jitter: float):
        assert timeout >= 0.0
        assert jitter >= 0.0
        self._sockets: typing.Dict[Socket, float] = {}
        self._timeout = timeout
        self._jitter = jitter

    def raise_if_timed_out(self, socket_from: Socket) -> None:
        """
        Register `socket_from` on first call; raise once its deadline passed.

        @throws GateInterceptException when the deadline is reached; the
                socket is forgotten at that point
        """
        deadline = self._sockets.get(socket_from)
        if deadline is None:
            extra = self._jitter * random.random()
            deadline = time.monotonic() + self._timeout + extra
            self._sockets[socket_from] = deadline

        if deadline <= time.monotonic():
            del self._sockets[socket_from]
            raise GateInterceptException('Socket hits the time limit')

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        self.raise_if_timed_out(socket_from)
        await _intercept_ok(loop, socket_from, socket_to)
199
200
class _InterceptSmallerParts:
    """
    Forward data in chunks of at most `max_size` bytes, sleeping
    `sleep_per_packet` seconds before each send.
    """

    def __init__(self, max_size: int, sleep_per_packet: float):
        assert max_size > 0
        self._max_size = max_size
        self._sleep_per_packet = sleep_per_packet

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        piece_size = min(_incoming_data_size(socket_from), self._max_size)
        piece = await loop.sock_recv(socket_from, piece_size)
        await asyncio.sleep(self._sleep_per_packet)
        await loop.sock_sendall(socket_to, piece)
215
216
class _InterceptConcatPackets:
    """
    Hold data back until at least `packet_size` bytes are pending, then
    forward what is pending in one send.

    If no such packet is assembled within MAX_DELAY seconds, the process is
    terminated via sys.exit(2).
    """

    def __init__(self, packet_size: int):
        assert packet_size >= 0
        self._packet_size = packet_size
        self._expire_at: typing.Optional[float] = None

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        if self._expire_at is None:
            self._expire_at = time.monotonic() + MAX_DELAY

        if self._expire_at <= time.monotonic():
            logger.error(
                f'Failed to make a packet of sufficient size in {MAX_DELAY} '
                'seconds. Check the test logic, it should end with checking '
                'that the data was sent and by calling TcpGate function '
                'to_client_pass() to pass the remaining packets.',
            )
            sys.exit(2)

        # _incoming_data_size() peeks at most RECV_MAX_SIZE bytes, so a
        # packet_size above that limit could never be reached. Peek (and
        # later read) enough bytes to hold a whole packet instead.
        # NOTE(review): a packet_size above the socket receive buffer size
        # still cannot be reached.
        read_size = max(self._packet_size, RECV_MAX_SIZE)
        try:
            pending = socket_from.recv(read_size, socket.MSG_PEEK)
        except (BlockingIOError, InterruptedError):
            pending = b''

        if len(pending) >= self._packet_size:
            data = await loop.sock_recv(socket_from, read_size)
            await loop.sock_sendall(socket_to, data)
            self._expire_at = None
243
244
class _InterceptBytesLimit:
    """
    Forward data until `bytes_limit` bytes have passed, then close every
    connection of `gate` and start counting from zero again.
    """

    def __init__(self, bytes_limit: int, gate: 'BaseGate'):
        assert bytes_limit >= 0
        self._bytes_limit = bytes_limit
        self._bytes_remain = bytes_limit
        self._gate = gate

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)

        if len(data) < self._bytes_remain:
            self._bytes_remain -= len(data)
            await loop.sock_sendall(socket_to, data)
            return

        # Limit reached: forward only the allowed prefix, drop connections.
        await loop.sock_sendall(socket_to, data[:self._bytes_remain])
        await self._gate.sockets_close()
        self._bytes_remain = self._bytes_limit
        raise GateInterceptException('Data transmission limit reached')
264
265
class _InterceptSubstitute:
    """
    Apply a regex substitution to each received chunk, decoded as text.

    Chunks are processed independently: a match spanning two chunks is not
    found, and chunks that fail to decode/encode are forwarded unchanged.
    """

    def __init__(self, pattern: str, repl: str, encoding='utf-8'):
        self._pattern = re.compile(pattern)
        self._repl = repl
        self._encoding = encoding

    def _substitute(self, data: bytes) -> bytes:
        """Return `data` with the substitution applied, or as-is if not text."""
        try:
            text = data.decode(self._encoding)
            return self._pattern.sub(self._repl, text).encode(self._encoding)
        except UnicodeError:
            return data

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        received = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        await loop.sock_sendall(socket_to, self._substitute(received))
282
283
async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
    """
    Cancel `task` and wait until it has finished.

    None and already-cancelled tasks are ignored. CancelledError raised by
    the await is absorbed; any other exception is logged, not re-raised.
    """
    if task and not task.cancelled():
        try:
            task.cancel()
            await task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.error('Exception in _cancel_and_join: %s', exc)
295
296
def _make_socket_nonblocking(sock: Socket) -> None:
    """
    Put `sock` into non-blocking mode.

    For TCP sockets (AF_INET/AF_INET6 stream sockets) TCP_NODELAY is also
    enabled, so small chunks are not held back by Nagle's algorithm. Other
    stream sockets (e.g. AF_UNIX) reject that option, so it is skipped there.
    The descriptor's other file status flags are preserved.
    """
    sock.setblocking(False)
    if sock.type == socket.SOCK_STREAM and sock.family in (
        socket.AF_INET,
        socket.AF_INET6,
    ):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    # OR O_NONBLOCK into the existing flags instead of overwriting them all.
    flags = fcntl.fcntl(sock, fcntl.F_GETFL)
    fcntl.fcntl(sock, fcntl.F_SETFL, flags | os.O_NONBLOCK)
302
303
class _UdpDemuxSocketMock:
    """
    Socket-like stand-in for a single UDP peer behind a shared UDP socket.

    Datagrams from the peer are pushed into an internal datagram socketpair
    and read back from its other end, giving a non-blocking point-to-point
    socket interface. Outgoing data goes through the shared socket to the
    peer's address.
    """

    def __init__(self, sock: Socket, peer_address: Address):
        self._sock: Socket = sock
        self._peeraddr: Address = peer_address

        demux_in, demux_out = socket.socketpair(type=socket.SOCK_DGRAM)
        self._demux_in: Socket = demux_in
        self._demux_out: Socket = demux_out
        for end in (self._demux_in, self._demux_out):
            _make_socket_nonblocking(end)
        self._is_active: bool = True

    @property
    def peer_address(self):
        """Address of the UDP peer this object represents."""
        return self._peeraddr

    async def push(self, loop: EvLoop, data: bytes):
        """Queue a datagram received from the peer so it can be read."""
        outcome = await loop.sock_sendall(self._demux_in, data)
        return outcome

    def is_active(self):
        """False once close() has been called."""
        return self._is_active

    def close(self):
        """Close the internal socketpair; the shared socket stays open."""
        self._is_active = False
        for end in (self._demux_out, self._demux_in):
            end.close()

    def recvfrom(self, bufsize: int, flags: int = 0):
        """Read a queued datagram together with its (internal) source."""
        return self._demux_out.recvfrom(bufsize, flags)

    def recv(self, bufsize: int, flags: int = 0):
        """Read a queued datagram."""
        return self._demux_out.recv(bufsize, flags)

    def fileno(self):
        """Descriptor of the readable end, for event-loop registration."""
        return self._demux_out.fileno()

    def send(self, data: bytes):
        """Send `data` to the peer through the shared UDP socket."""
        return self._sock.sendto(data, self._peeraddr)
347
348
class _SocketsPaired:
    """
    One proxied connection: a client-side socket paired with a server-side
    socket, with one asyncio task piping data in each direction.

    When one direction stops, it cancels the task of the other direction;
    whichever task finishes second closes both sockets.
    """

    def __init__(
        self,
        proxy_name: str,
        loop: EvLoop,
        client: typing.Union[socket.socket, _UdpDemuxSocketMock],
        server: socket.socket,
        to_server_intercept: Interceptor,
        to_client_intercept: Interceptor,
    ) -> None:
        self._proxy_name = proxy_name
        self._loop = loop

        self._client = client
        self._server = server

        self._to_server_intercept: Interceptor = to_server_intercept
        self._to_client_intercept: Interceptor = to_client_intercept

        self._task_to_server = asyncio.create_task(
            self._do_pipe_channels(to_server=True),
        )
        self._task_to_client = asyncio.create_task(
            self._do_pipe_channels(to_server=False),
        )

        # Initialized after the tasks are created. This relies on
        # create_task() not running the tasks before this constructor
        # returns (true for the default, non-eager task factory).
        self._finished_channels = 0

    async def _do_pipe_channels(self, *, to_server: bool) -> None:
        """
        Pipe data in one direction until an interceptor raises
        GateInterceptException, a socket error occurs or the task is
        cancelled.

        @param to_server True for client -> server, False for server -> client
        """
        if to_server:
            socket_from = self._client
            socket_to = self._server
        else:
            socket_from = self._server
            socket_to = self._client

        try:
            while True:
                # The interceptor is called only when data is pending, and is
                # looked up anew on every iteration, so an interceptor set
                # via set_to_*_interceptor() is picked up on the next chunk.
                # NOTE(review): _incoming_data_size() returns 0 both when no
                # data is pending and when the peek returns b'' (end of
                # stream), so a peer closing its side is not detected here;
                # the loop keeps polling until cancelled or an error occurs.
                if not _incoming_data_size(socket_from):
                    await _yield()
                    continue

                if to_server:
                    interceptor = self._to_server_intercept
                else:
                    interceptor = self._to_client_intercept

                await interceptor(self._loop, socket_from, socket_to)
                await _yield()
        except GateInterceptException as exc:
            logger.info('In "%s": %s', self._proxy_name, exc)
        except socket.error as exc:
            logger.error('Exception in "%s": %s', self._proxy_name, exc)
        finally:
            self._finished_channels += 1
            if self._finished_channels == 2:
                # Both directions have stopped: the last one closes sockets.
                logger.info('"%s" closes %s', self._proxy_name, self.info())
                self._close_socket(self._client)
                self._close_socket(self._server)
            else:
                assert self._finished_channels == 1
                # First direction to stop: cancel the opposite task; its own
                # finally block will then close both sockets.
                if to_server:
                    self._task_to_client.cancel()
                else:
                    self._task_to_server.cancel()

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """Replace the client -> server interceptor of this connection."""
        self._to_server_intercept = interceptor

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """Replace the server -> client interceptor of this connection."""
        self._to_client_intercept = interceptor

    def _close_socket(self, self_socket: Socket) -> None:
        """Close one of the two paired sockets, logging (not raising) errors."""
        assert self_socket in {self._client, self._server}
        try:
            self_socket.close()
        except socket.error as exc:
            logger.error(
                'Exception in "%s" on closing %s: %s',
                self._proxy_name,
                'client' if self_socket == self._client else 'server',
                exc,
            )

    async def shutdown(self) -> None:
        """Cancel both piping tasks and wait for them to finish."""
        for task in {self._task_to_client, self._task_to_server}:
            await _cancel_and_join(task)

    def is_active(self) -> bool:
        """True while at least one of the two piping tasks is still running."""
        return (
            not self._task_to_client.done() or not self._task_to_server.done()
        )

    def info(self) -> str:
        """Describe the pair by file descriptors, or '<inactive>'."""
        if not self.is_active():
            return '<inactive>'

        return (
            f'client fd={self._client.fileno()} <=> '
            f'server fd={self._server.fileno()}'
        )
457
458
459
460
461
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = (
        'Do not use BaseGate itself, use one of '
        'specializations TcpGate or UdpGate'
    )

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        self._route = route
        self._loop = loop

        # Interceptors handed to every newly created connection pair.
        self._to_server_intercept: Interceptor = _intercept_ok
        self._to_client_intercept: Interceptor = _intercept_ok

        self._accept_sockets: typing.List[socket.socket] = []
        self._accept_tasks: typing.List[asyncio.Task[None]] = []

        self._sockets: typing.Set[_SocketsPaired] = set()

    async def __aenter__(self) -> 'BaseGate':
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create the bound sockets that receive client traffic."""
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self):
        """
        Open the socket and start accepting tasks.

        Does nothing if the gate is already started.

        @throws GateException if no accepting socket could be created
        """
        if self._accept_sockets:
            return

        self._accept_sockets.extend(self._create_accepting_sockets())

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # The OS picked the port: remember the actual address of the
            # first accepting socket so get_sockname_for_clients() reports it.
            host, port = self._accept_sockets[0].getsockname()[:2]
            self._route = dataclasses.replace(
                self._route, host_for_client=host, port_for_client=port,
            )

        BaseGate.start_accepting(self)

    def start_accepting(self) -> None:
        """Start accepting tasks (no-op while some are still running)"""
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Stop accepting tasks, close all the sockets
        """
        if not self._accept_sockets and not self._sockets:
            return

        # Restore pass-through interceptors before tearing everything down.
        self.to_server_pass()
        self.to_client_pass()

        await BaseGate.stop_accepting(self)
        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        for sock in self._accept_sockets:
            sock.close()
        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
        self, *, count: typing.Optional[int] = None,
    ) -> None:
        """
        Close the connections going through the gate.

        @param count close at most this many connections; all if None
        """
        for x in list(self._sockets)[0:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        # The message used to be wrapped in a 1-tuple by a stray comma.
        assert self._route.port_for_client != 0, (
            'Gate was not started and the port_for_client is still 0'
        )
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """Print info on open sockets"""
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(
            x.info() for x in self._sockets
        )

    def _collect_garbage(self) -> None:
        """Forget connection pairs whose piping tasks have both finished."""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of client to server data with a custom
        """
        self._to_server_intercept = interceptor
        for x in self._sockets:
            x.set_to_server_interceptor(self._to_server_intercept)

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of server to client data with a custom
        """
        self._to_client_intercept = interceptor
        for x in self._sockets:
            x.set_to_client_interceptor(self._to_client_intercept)

    # The stdlib `logging` module has no `trace()` function, so the helpers
    # below log through the module logger at DEBUG level.

    def to_server_pass(self) -> None:
        """Pass data as is"""
        logger.debug('to_server_pass')
        self.set_to_server_interceptor(_intercept_ok)

    def to_client_pass(self) -> None:
        """Pass data as is"""
        logger.debug('to_client_pass')
        self.set_to_client_interceptor(_intercept_ok)

    def to_server_noop(self) -> None:
        """Do not read data, causing client to keep multiple data"""
        logger.debug('to_server_noop')
        self.set_to_server_interceptor(_intercept_noop)

    def to_client_noop(self) -> None:
        """Do not read data, causing server to keep multiple data"""
        logger.debug('to_client_noop')
        self.set_to_client_interceptor(_intercept_noop)

    def to_server_drop(self) -> None:
        """Read and discard data"""
        logger.debug('to_server_drop')
        self.set_to_server_interceptor(_intercept_drop)

    def to_client_drop(self) -> None:
        """Read and discard data"""
        logger.debug('to_client_drop')
        self.set_to_client_interceptor(_intercept_drop)

    def to_server_delay(self, delay: float) -> None:
        """Delay data transmission"""
        logger.debug('to_server_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop, socket_from: Socket, socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_server_interceptor(_intercept_delay_bound)

    def to_client_delay(self, delay: float) -> None:
        """Delay data transmission"""
        logger.debug('to_client_delay, delay: %s', delay)

        async def _intercept_delay_bound(
            loop: EvLoop, socket_from: Socket, socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_client_interceptor(_intercept_delay_bound)

    def to_server_close_on_data(self) -> None:
        """Close on first bytes of data from client"""
        logger.debug('to_server_close_on_data')
        self.set_to_server_interceptor(_intercept_close_on_data)

    def to_client_close_on_data(self) -> None:
        """Close on first bytes of data from server"""
        logger.debug('to_client_close_on_data')
        self.set_to_client_interceptor(_intercept_close_on_data)

    def to_server_corrupt_data(self) -> None:
        """Corrupt data received from client"""
        logger.debug('to_server_corrupt_data')
        self.set_to_server_interceptor(_intercept_corrupt)

    def to_client_corrupt_data(self) -> None:
        """Corrupt data received from server"""
        logger.debug('to_client_corrupt_data')
        self.set_to_client_interceptor(_intercept_corrupt)

    def to_server_limit_bps(self, bytes_per_second: float) -> None:
        """Limit bytes per second transmission by network from client"""
        logger.debug(
            'to_server_limit_bps, bytes_per_second: %s', bytes_per_second,
        )
        self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_client_limit_bps(self, bytes_per_second: float) -> None:
        """Limit bytes per second transmission by network from server"""
        logger.debug(
            'to_client_limit_bps, bytes_per_second: %s', bytes_per_second,
        )
        self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_server_limit_time(self, timeout: float, jitter: float) -> None:
        """Limit connection lifetime on receive of first bytes from client"""
        logger.debug(
            'to_server_limit_time, timeout: %s, jitter: %s', timeout, jitter,
        )
        self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_client_limit_time(self, timeout: float, jitter: float) -> None:
        """Limit connection lifetime on receive of first bytes from server"""
        logger.debug(
            'to_client_limit_time, timeout: %s, jitter: %s', timeout, jitter,
        )
        self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_server_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_server_smaller_parts, max_size: %s', max_size)
        self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_client_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_client_smaller_parts, max_size: %s', max_size)
        self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_server_concat_packets, packet_size: %s', packet_size)
        self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    def to_client_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_client_concat_packets, packet_size: %s', packet_size)
        self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    def to_server_limit_bytes(self, bytes_limit: int) -> None:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logger.debug('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_client_limit_bytes(self, bytes_limit: int) -> None:
        """Drop all connections each `bytes_limit` of data sent by network"""
        logger.debug('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_server_substitute(self, pattern: str, repl: str) -> None:
        """Apply regex substitution to data from client"""
        logger.debug(
            'to_server_substitute, pattern: %s, repl: %s', pattern, repl,
        )
        self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    def to_client_substitute(self, pattern: str, repl: str) -> None:
        """Apply regex substitution to data from server"""
        logger.debug(
            'to_client_substitute, pattern: %s, repl: %s', pattern, repl,
        )
        self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
796
797
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        # Set whenever a new connection pair is registered.
        self._connected_event = asyncio.Event()
        BaseGate.__init__(self, route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
                 right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        @param count required number of connections
        @param timeout seconds to wait; a value <= 0 waits without limit
        @throws asyncio.TimeoutError exception if failed to get the
                required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(), timeout=time_left,
            )
            self._connected_event.clear()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create a listening socket for every address `host_for_client`
        resolves to. On failure, sockets created so far are closed before the
        error propagates.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                # Track it right away so it is closed if bind/listen fails.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    'Accepting connections on %s, fd=%s',
                    sock.getsockname(),
                    sock.fileno(),
                )
        except Exception:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect to (host_to_server, port_to_server).

        Resolved addresses are tried in order; sockets of failed attempts are
        closed.

        @returns the connected socket, or None if no address accepted
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            # Logged outside the try: a logging failure must not be reported
            # as a failed connect (and close a connected socket).
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        """Accept clients forever, pairing each with a server connection."""
        while True:
            client, _ = await self._loop.sock_accept(accept_sock)
            _make_socket_nonblocking(client)

            try:
                server = await self._connect_to_server()
            except BaseException:
                # Includes cancellation by stop_accepting(): do not leak the
                # accepted client socket.
                client.close()
                raise

            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
903
904
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as demuxing incoming datagrams
    from different clients.
    Separate connections to server are made for each new client.

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        # One demultiplexed pseudo-socket per known client address.
        self._clients: typing.Set[_UdpDemuxSocketMock] = set()
        BaseGate.__init__(self, route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create a bound UDP socket for every address `host_for_client`
        resolves to. On failure, sockets created so far are closed before the
        error propagates.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                self._route.host_for_client,
                self._route.port_for_client,
                type=socket.SOCK_DGRAM,
            ):
                sock = socket.socket(addr[0], addr[1])
                # Track it right away so it is closed if bind fails.
                res.append(sock)
                _make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except Exception:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect a UDP socket to (host_to_server, port_to_server).

        Resolved addresses are tried in order; sockets of failed attempts are
        closed (previously they were leaked).

        @returns the connected socket, or None if no address accepted
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                _make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    def _collect_garbage(self) -> None:
        """Forget finished connection pairs and closed client pseudo-sockets."""
        super()._collect_garbage()
        self._clients = {c for c in self._clients if c.is_active()}

    async def _do_accept(self, accept_sock: Socket):
        """
        Receive datagrams on `accept_sock` and route each to the pseudo-socket
        of its sender, creating a new server connection for unknown senders.

        If connecting to the server fails, `accept_sock` is closed and the
        task ends.
        """
        while True:
            data, addr = await _get_message_task(accept_sock)

            client: typing.Optional[_UdpDemuxSocketMock] = next(
                (known for known in self._clients if known.peer_address == addr),
                None,
            )

            if client is None:
                server = await self._connect_to_server()
                if not server:
                    accept_sock.close()
                    break

                client = _UdpDemuxSocketMock(accept_sock, addr)
                self._clients.add(client)

                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )

            await client.push(self._loop, data)
            self._collect_garbage()

    def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_server_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    def to_client_smaller_parts(
        self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')