userver: testsuite/pytest_plugins/pytest_userver/chaos.py
Loading...
Searching...
No Matches
testsuite/pytest_plugins/pytest_userver/chaos.py
1"""
2Python module that provides testsuite support for
3chaos tests; see
4@ref scripts/docs/en/userver/chaos_testing.md for an introduction.
5
6@ingroup userver_testsuite
7"""
8
9import asyncio
10import dataclasses
11import errno
12import fcntl
13import logging
14import os
15import random
16import re
17import socket
18import sys
19import time
20import typing
21
22
@dataclasses.dataclass(frozen=True)
class GateRoute:
    """
    Immutable description of a single route of a TcpGate or UdpGate.

    Clients talk to (host_for_client, port_for_client); the gate forwards
    their traffic to (host_to_server, port_to_server).

    Set `port_for_client` to 0 to let the gate bind to any unused port; the
    address actually chosen is returned by
    BaseGate.get_sockname_for_clients().

    @ingroup userver_testsuite
    """

    # Route name; it is used to tag the gate's log messages.
    name: str
    # Where the real server listens.
    host_to_server: str
    port_to_server: int
    # Where the gate listens for clients; port 0 means "pick a free port".
    host_for_client: str = dataclasses.field(default='127.0.0.1')
    port_for_client: int = dataclasses.field(default=0)
39
40
# @cond

# Largest chunk read from a socket in one call. The socket.recv docs suggest
# a relatively small power of two such as 4096:
# https://docs.python.org/3/library/socket.html#socket.socket.recv
RECV_MAX_SIZE = 4096
# Upper bound, in seconds, an interceptor waits for enough data to arrive.
MAX_DELAY = 60.0


logger = logging.getLogger(__name__)


Address = typing.Tuple[str, int]
EvLoop = typing.Any
Socket = socket.socket
# An interceptor is awaited as interceptor(loop, socket_from, socket_to) and
# decides what (if anything) is moved from socket_from to socket_to.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
57
58
class GateException(Exception):
    """Raised when a gate cannot be set up, e.g. its host does not resolve."""
61
62
class GateInterceptException(Exception):
    """Raised by an interceptor to terminate the connection it serves."""
65
66
67async def _yield() -> None:
68 # Minamal delay can be 0. This will be fast path for coroutine switching
69 # https://docs.python.org/3/library/asyncio-task.html#sleeping
70
71 _MIN_DELAY = 0
72 await asyncio.sleep(_MIN_DELAY)
73
74
def _try_get_message(
        recv_socket: Socket,
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
    """
    Peek at the data pending on a non-blocking socket without consuming it.

    @param recv_socket non-blocking socket to peek at
    @returns (data, address) as reported by recvfrom(MSG_PEEK) (at most
             RECV_MAX_SIZE bytes), or (None, None) if nothing can be read
             right now
    @throws OSError for every socket error other than EAGAIN/EWOULDBLOCK
    """
    try:
        return recv_socket.recvfrom(RECV_MAX_SIZE, socket.MSG_PEEK)
    except OSError as exc:  # socket.error is an alias of OSError
        if exc.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
            return None, None
        # Bare raise keeps the original traceback intact.
        raise
86
87
async def _wait_for_message_task(
        recv_socket: Socket,
) -> typing.Tuple[bytes, Address]:
    """
    Poll `recv_socket` until a non-empty message is pending.

    The message is only peeked, not consumed; it is returned together with
    the sender address.
    """
    msg, addr = _try_get_message(recv_socket)
    while not msg:
        await _yield()
        msg, addr = _try_get_message(recv_socket)
    assert addr
    return msg, addr
98
99
def _incoming_data_size(recv_socket: Socket) -> int:
    """
    Size of the data pending on `recv_socket`, as seen by a peek of at most
    RECV_MAX_SIZE bytes. Returns 0 when the peek yields nothing.
    """
    pending, _ = _try_get_message(recv_socket)
    if not pending:
        return 0
    return len(pending)
103
104
async def _intercept_ok(
        loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Forward one chunk of up to RECV_MAX_SIZE bytes unchanged."""
    await loop.sock_sendall(
        socket_to, await loop.sock_recv(socket_from, RECV_MAX_SIZE),
    )
110
111
112async def _intercept_noop(
113 loop: EvLoop, socket_from: Socket, socket_to: Socket,
114) -> None:
115 pass
116
117
async def _intercept_delay(
        delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Read one chunk, hold it back for `delay` seconds, then forward it."""
    chunk = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, chunk)
124
125
async def _intercept_close_on_data(
        loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Consume a single byte from `socket_from`, forward nothing, and abort the
    connection by raising GateInterceptException.
    """
    _ = await loop.sock_recv(socket_from, 1)
    raise GateInterceptException('Closing socket on data')
131
132
async def _intercept_corrupt(
        loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """Forward one chunk with every byte altered."""
    payload = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    # Zero bytes become 1, every other byte becomes 0, so no byte survives.
    corrupted = bytearray(int(not byte) for byte in payload)
    await loop.sock_sendall(socket_to, corrupted)
138
139
140class _InterceptBpsLimit:
141 def __init__(self, bytes_per_second: float):
142 assert bytes_per_second >= 1
143 self._bytes_per_second = bytes_per_second
144 self._time_last_added = 0.0
145 self._bytes_left = self._bytes_per_second
146
147 def _update_limit(self) -> None:
148 current_time = time.monotonic()
149 elapsed = current_time - self._time_last_added
150 bytes_addition = self._bytes_per_second * elapsed
151 if bytes_addition > 0:
152 self._bytes_left += bytes_addition
153 self._time_last_added = current_time
154
155 if self._bytes_left > self._bytes_per_second:
156 self._bytes_left = self._bytes_per_second
157
158 async def __call__(
159 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
160 ) -> None:
161 self._update_limit()
162
163 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
164 if bytes_to_recv > 0:
165 data = await loop.sock_recv(socket_from, bytes_to_recv)
166 self._bytes_left -= len(data)
167
168 await loop.sock_sendall(socket_to, data)
169 else:
170 logger.info('Socket hits the bytes per second limit')
171 await asyncio.sleep(1.0 / self._bytes_per_second)
172
173
174class _InterceptTimeLimit:
175 def __init__(self, timeout: float, jitter: float):
176 self._sockets: typing.Dict[Socket, float] = {}
177 assert timeout >= 0.0
178 self._timeout = timeout
179 assert jitter >= 0.0
180 self._jitter = jitter
181
182 def raise_if_timed_out(self, socket_from: Socket) -> None:
183 if socket_from not in self._sockets:
184 jitter = self._jitter * random.random()
185 expire_at = time.monotonic() + self._timeout + jitter
186 self._sockets[socket_from] = expire_at
187
188 if self._sockets[socket_from] <= time.monotonic():
189 del self._sockets[socket_from]
190 raise GateInterceptException('Socket hits the time limit')
191
192 async def __call__(
193 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
194 ) -> None:
195 self.raise_if_timed_out(socket_from)
196 await _intercept_ok(loop, socket_from, socket_to)
197
198
199class _InterceptSmallerParts:
200 def __init__(self, max_size: int, sleep_per_packet: float):
201 assert max_size > 0
202 self._max_size = max_size
203 self._sleep_per_packet = sleep_per_packet
204
205 async def __call__(
206 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
207 ) -> None:
208 incoming_size = _incoming_data_size(socket_from)
209 chunk_size = min(incoming_size, self._max_size)
210 data = await loop.sock_recv(socket_from, chunk_size)
211 await asyncio.sleep(self._sleep_per_packet)
212 await loop.sock_sendall(socket_to, data)
213
214
215class _InterceptConcatPackets:
216 def __init__(self, packet_size: int):
217 assert packet_size >= 0
218 self._packet_size = packet_size
219 self._expire_at: typing.Optional[float] = None
220
221 async def __call__(
222 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
223 ) -> None:
224 if self._expire_at is None:
225 self._expire_at = time.monotonic() + MAX_DELAY
226
227 if self._expire_at <= time.monotonic():
228 logger.error(
229 f'Failed to make a packet of sufficient size in {MAX_DELAY} '
230 'seconds. Check the test logic, it should end with checking '
231 'that the data was sent and by calling TcpGate function '
232 'to_client_pass() to pass the remaining packets.',
233 )
234 sys.exit(2)
235
236 incoming_size = _incoming_data_size(socket_from)
237 if incoming_size >= self._packet_size:
238 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
239 await loop.sock_sendall(socket_to, data)
240 self._expire_at = None
241
242
243class _InterceptBytesLimit:
244 def __init__(self, bytes_limit: int, gate: 'BaseGate'):
245 assert bytes_limit >= 0
246 self._bytes_limit = bytes_limit
247 self._bytes_remain = self._bytes_limit
248 self._gate = gate
249
250 async def __call__(
251 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
252 ) -> None:
253 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
254 if self._bytes_remain <= len(data):
255 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
256 await self._gate.sockets_close()
257 self._bytes_remain = self._bytes_limit
258 raise GateInterceptException('Data transmission limit reached')
259
260 self._bytes_remain -= len(data)
261 await loop.sock_sendall(socket_to, data)
262
263
264class _InterceptSubstitute:
265 def __init__(self, pattern: str, repl: str, encoding='utf-8'):
266 self._pattern = re.compile(pattern)
267 self._repl = repl
268 self._encoding = encoding
269
270 async def __call__(
271 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
272 ) -> None:
273 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
274 try:
275 res = self._pattern.sub(self._repl, data.decode(self._encoding))
276 data = res.encode(self._encoding)
277 except UnicodeError:
278 pass
279 await loop.sock_sendall(socket_to, data)
280
281
282async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
283 if not task or task.cancelled():
284 return
285
286 try:
287 task.cancel()
288 await task
289 except asyncio.CancelledError:
290 return
291 except Exception as exc: # pylint: disable=broad-except
292 logger.error('Exception in _cancel_and_join: %s', exc)
293
294
295class _SocketsPaired:
296 def __init__(
297 self,
298 proxy_name: str,
299 loop: EvLoop,
300 client: socket.socket,
301 server: socket.socket,
302 to_server_intercept: Interceptor,
303 to_client_intercept: Interceptor,
304 ) -> None:
305 self._proxy_name = proxy_name
306 self._loop = loop
307
308 self._client = client
309 self._server = server
310
311 self._to_server_intercept: Interceptor = to_server_intercept
312 self._to_client_intercept: Interceptor = to_client_intercept
313
314 self._task_to_server = asyncio.create_task(
315 self._do_pipe_channels(to_server=True),
316 )
317 self._task_to_client = asyncio.create_task(
318 self._do_pipe_channels(to_server=False),
319 )
320
321 self._finished_channels = 0
322
323 async def _do_pipe_channels(self, *, to_server: bool) -> None:
324 if to_server:
325 socket_from = self._client
326 socket_to = self._server
327 else:
328 socket_from = self._server
329 socket_to = self._client
330
331 try:
332 while True:
333 # Applies new interceptors faster.
334 #
335 # To avoid long awaiting on sock_recv in an outdated
336 # interceptor we wait for data before grabbing and applying
337 # the interceptor.
338 if not _incoming_data_size(socket_from):
339 await _yield()
340 continue
341
342 if to_server:
343 interceptor = self._to_server_intercept
344 else:
345 interceptor = self._to_client_intercept
346
347 await interceptor(self._loop, socket_from, socket_to)
348 await _yield()
349 except GateInterceptException as exc:
350 logger.info('In "%s": %s', self._proxy_name, exc)
351 except socket.error as exc:
352 logger.error('Exception in "%s": %s', self._proxy_name, exc)
353 finally:
354 self._finished_channels += 1
355 if self._finished_channels == 2:
356 # Closing the sockets here so that the self.shutdown()
357 # returns only when the sockets are actually closed
358 logger.info('"%s" closes %s', self._proxy_name, self.info())
359 self._close_socket(self._client)
360 self._close_socket(self._server)
361 else:
362 assert self._finished_channels == 1
363 if to_server:
364 self._task_to_client.cancel()
365 else:
366 self._task_to_server.cancel()
367
368 def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
369 self._to_server_intercept = interceptor
370
371 def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
372 self._to_client_intercept = interceptor
373
374 def _close_socket(self, self_socket: Socket) -> None:
375 assert self_socket in {self._client, self._server}
376 try:
377 self_socket.close()
378 except socket.error as exc:
379 logger.error(
380 'Exception in "%s" on closing %s: %s',
381 self._proxy_name,
382 'client' if self_socket == self._client else 'server',
383 exc,
384 )
385
386 async def shutdown(self) -> None:
387 for task in {self._task_to_client, self._task_to_server}:
388 await _cancel_and_join(task)
389
390 def is_active(self) -> bool:
391 return (
392 not self._task_to_client.done() or not self._task_to_server.done()
393 )
394
395 def info(self) -> str:
396 if not self.is_active():
397 return '<inactive>'
398
399 return (
400 f'client fd={self._client.fileno()} <=> '
401 f'server fd={self._server.fileno()}'
402 )
403
404
405# @endcond
406
407
class BaseGate:
    """
    This base class maintains endpoints of two types:

    Server-side endpoints to receive messages from clients. Address of this
    endpoint is described by (host_for_client, port_for_client).

    Client-side endpoints to forward messages to server. Server must listen on
    (host_to_server, port_to_server).

    Asynchronously concurrently passes data from client to server and from
    server to client, allowing intercepting the data, injecting delays and
    dropping connections.

    @warning Do not use this class itself. Use one of the specializations
    TcpGate or UdpGate

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_MESSAGE = (
        'Do not use BaseGate itself, use one of '
        'specializations TcpGate or UdpGate'
    )

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        self._route = route
        self._loop = loop

        # Interceptors handed to every newly established connection pair.
        self._to_server_intercept: Interceptor = _intercept_ok
        self._to_client_intercept: Interceptor = _intercept_ok

        self._accept_sockets: typing.List[socket.socket] = []
        self._accept_tasks: typing.List[asyncio.Task[None]] = []

        # Set by the subclasses' _do_accept whenever a new pair is added.
        self._connected_event = asyncio.Event()

        self._sockets: typing.Set[_SocketsPaired] = set()

    async def __aenter__(self) -> 'BaseGate':
        self.start()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.stop()

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """Create the sockets clients talk to; provided by subclasses."""
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def start(self):
        """ Open the socket and start accepting tasks """
        if self._accept_sockets:
            return

        self._accept_sockets.extend(self._create_accepting_sockets())

        if not self._accept_sockets:
            raise GateException(
                f'Could not resolve hostname {self._route.host_for_client}',
            )

        if self._route.port_for_client == 0:
            # Remember the address picked by the OS, so that stop() + start()
            # binds to the same port again.
            sockname = self._accept_sockets[0].getsockname()
            self._route = dataclasses.replace(
                self._route,
                host_for_client=sockname[0],
                port_for_client=sockname[1],
            )

        BaseGate.start_accepting(self)

    def start_accepting(self) -> None:
        """ Start accepting tasks """
        assert self._accept_sockets
        if not all(tsk.done() for tsk in self._accept_tasks):
            return

        self._accept_tasks.clear()
        for sock in self._accept_sockets:
            self._accept_tasks.append(
                asyncio.create_task(self._do_accept(sock)),
            )

    async def stop_accepting(self) -> None:
        """
        Stop accepting tasks without closing the accepting socket.
        """
        for tsk in self._accept_tasks:
            await _cancel_and_join(tsk)
        self._accept_tasks.clear()

    async def stop(self) -> None:
        """
        Stop accepting tasks, close all the sockets
        """
        if not self._accept_sockets and not self._sockets:
            return

        self.to_server_pass()
        self.to_client_pass()

        # Cancel the accept tasks before closing the sockets they wait on,
        # so that no pending accept is left waiting on a closed descriptor.
        await BaseGate.stop_accepting(self)
        for sock in self._accept_sockets:
            sock.close()

        logger.info('Before close() %s', self.info())
        await self.sockets_close()
        assert not self._sockets

        self._accept_sockets.clear()
        logger.info('Stopped. %s', self.info())

    async def sockets_close(
            self, *, count: typing.Optional[int] = None,
    ) -> None:
        """
        Close the connections going through the gate.

        @param count close only the first `count` connections; all of them
               when None
        """
        for x in list(self._sockets)[0:count]:
            await x.shutdown()
        self._collect_garbage()

    def get_sockname_for_clients(self) -> Address:
        """
        Returns the client socket address that the gate listens on.

        This function allows to use 0 in GateRoute.port_for_client and retrieve
        the actual port and host.
        """
        # No trailing comma: the assertion message must be a string, not a
        # one-element tuple.
        assert self._route.port_for_client != 0, (
            'Gate was not started and the port_for_client is still 0'
        )
        return (self._route.host_for_client, self._route.port_for_client)

    def info(self) -> str:
        """ Print info on open sockets """
        if not self._sockets:
            return f'"{self._route.name}" no active sockets'

        return f'"{self._route.name}" active sockets:\n\t' + '\n\t'.join(
            x.info() for x in self._sockets
        )

    def _collect_garbage(self) -> None:
        """Forget connection pairs whose piping tasks have both finished."""
        self._sockets = {x for x in self._sockets if x.is_active()}

    async def _do_accept(self, accept_sock: Socket) -> None:
        """
        This task should wait for connections and create SocketPair
        """
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of client to server data with a custom
        """
        self._to_server_intercept = interceptor
        for x in self._sockets:
            x.set_to_server_interceptor(self._to_server_intercept)

    def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
        """
        Replace existing interceptor of server to client data with a custom
        """
        self._to_client_intercept = interceptor
        for x in self._sockets:
            x.set_to_client_interceptor(self._to_client_intercept)

    def to_server_pass(self) -> None:
        """ Pass data as is """
        logger.debug('to_server_pass')
        self.set_to_server_interceptor(_intercept_ok)

    def to_client_pass(self) -> None:
        """ Pass data as is """
        logger.debug('to_client_pass')
        self.set_to_client_interceptor(_intercept_ok)

    def to_server_noop(self) -> None:
        """ Do not read data, causing client to keep multiple data """
        logger.debug('to_server_noop')
        self.set_to_server_interceptor(_intercept_noop)

    def to_client_noop(self) -> None:
        """ Do not read data, causing server to keep multiple data """
        logger.debug('to_client_noop')
        self.set_to_client_interceptor(_intercept_noop)

    def to_server_delay(self, delay: float) -> None:
        """ Delay data transmission """
        logger.debug('to_server_delay, delay: %s', delay)

        async def _intercept_delay_bound(
                loop: EvLoop, socket_from: Socket, socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_server_interceptor(_intercept_delay_bound)

    def to_client_delay(self, delay: float) -> None:
        """ Delay data transmission """
        logger.debug('to_client_delay, delay: %s', delay)

        async def _intercept_delay_bound(
                loop: EvLoop, socket_from: Socket, socket_to: Socket,
        ) -> None:
            await _intercept_delay(delay, loop, socket_from, socket_to)

        self.set_to_client_interceptor(_intercept_delay_bound)

    def to_server_close_on_data(self) -> None:
        """ Close on first bytes of data from client """
        logger.debug('to_server_close_on_data')
        self.set_to_server_interceptor(_intercept_close_on_data)

    def to_client_close_on_data(self) -> None:
        """ Close on first bytes of data from server """
        logger.debug('to_client_close_on_data')
        self.set_to_client_interceptor(_intercept_close_on_data)

    def to_server_corrupt_data(self) -> None:
        """ Corrupt data received from client """
        logger.debug('to_server_corrupt_data')
        self.set_to_server_interceptor(_intercept_corrupt)

    def to_client_corrupt_data(self) -> None:
        """ Corrupt data received from server """
        logger.debug('to_client_corrupt_data')
        self.set_to_client_interceptor(_intercept_corrupt)

    def to_server_limit_bps(self, bytes_per_second: float) -> None:
        """ Limit bytes per second transmission by network from client """
        logger.debug(
            'to_server_limit_bps, bytes_per_second: %s', bytes_per_second,
        )
        self.set_to_server_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_client_limit_bps(self, bytes_per_second: float) -> None:
        """ Limit bytes per second transmission by network from server """
        logger.debug(
            'to_client_limit_bps, bytes_per_second: %s', bytes_per_second,
        )
        self.set_to_client_interceptor(_InterceptBpsLimit(bytes_per_second))

    def to_server_limit_time(self, timeout: float, jitter: float) -> None:
        """ Limit connection lifetime on receive of first bytes from client """
        logger.debug(
            'to_server_limit_time, timeout: %s, jitter: %s', timeout, jitter,
        )
        self.set_to_server_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_client_limit_time(self, timeout: float, jitter: float) -> None:
        """ Limit connection lifetime on receive of first bytes from server """
        logger.debug(
            'to_client_limit_time, timeout: %s, jitter: %s', timeout, jitter,
        )
        self.set_to_client_interceptor(_InterceptTimeLimit(timeout, jitter))

    def to_server_smaller_parts(
            self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to server in smaller parts

        @param max_size Max packet size to send to server
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_server_smaller_parts, max_size: %s', max_size)
        self.set_to_server_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_client_smaller_parts(
            self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        """
        Pass data to client in smaller parts

        @param max_size Max packet size to send to client
        @param sleep_per_packet Optional sleep interval per packet, seconds
        """
        logger.debug('to_client_smaller_parts, max_size: %s', max_size)
        self.set_to_client_interceptor(
            _InterceptSmallerParts(max_size, sleep_per_packet),
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_server_concat_packets, packet_size: %s', packet_size)
        self.set_to_server_interceptor(_InterceptConcatPackets(packet_size))

    def to_client_concat_packets(self, packet_size: int) -> None:
        """
        Pass data in bigger parts
        @param packet_size minimal size of the resulting packet
        """
        logger.debug('to_client_concat_packets, packet_size: %s', packet_size)
        self.set_to_client_interceptor(_InterceptConcatPackets(packet_size))

    def to_server_limit_bytes(self, bytes_limit: int) -> None:
        """ Drop all connections each `bytes_limit` of data sent by network """
        logger.debug('to_server_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_server_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_client_limit_bytes(self, bytes_limit: int) -> None:
        """ Drop all connections each `bytes_limit` of data sent by network """
        logger.debug('to_client_limit_bytes, bytes_limit: %s', bytes_limit)
        self.set_to_client_interceptor(_InterceptBytesLimit(bytes_limit, self))

    def to_server_substitute(self, pattern: str, repl: str) -> None:
        """ Apply regex substitution to data from client """
        logger.debug(
            'to_server_substitute, pattern: %s, repl: %s', pattern, repl,
        )
        self.set_to_server_interceptor(_InterceptSubstitute(pattern, repl))

    def to_client_substitute(self, pattern: str, repl: str) -> None:
        """ Apply regex substitution to data from server """
        logger.debug(
            'to_client_substitute, pattern: %s, repl: %s', pattern, repl,
        )
        self.set_to_client_interceptor(_InterceptSubstitute(pattern, repl))
735
736
class TcpGate(BaseGate):
    """
    Implements TCP chaos-proxy logic such as accepting incoming tcp client
    connections. On each new connection new tcp client connects to server
    (host_to_server, port_to_server).

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    def __init__(self, route: GateRoute, loop: EvLoop) -> None:
        BaseGate.__init__(self, route, loop)

    def connections_count(self) -> int:
        """
        Returns maximal amount of connections going through the gate at
        the moment.

        @warning Some of the connections could be closing, or could be opened
        right before the function starts. Use with caution!
        """
        return len(self._sockets)

    async def wait_for_connections(self, *, count=1, timeout=0.0) -> None:
        """
        Wait for at least `count` connections going through the gate.

        @throws asyncio.TimeoutError exception if failed to get the
        required amount of connections in time.
        """
        if timeout <= 0.0:
            while self.connections_count() < count:
                await self._connected_event.wait()
                self._connected_event.clear()
            return

        deadline = time.monotonic() + timeout
        while self.connections_count() < count:
            time_left = deadline - time.monotonic()
            await asyncio.wait_for(
                self._connected_event.wait(), timeout=time_left,
            )
            self._connected_event.clear()

    def _make_socket_nonblocking(self, sock: Socket) -> None:
        """Switch `sock` to non-blocking mode and disable Nagle's algorithm."""
        sock.setblocking(False)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create one listening socket per resolved client address.

        If any socket fails to be created, bound or put into listening mode,
        the sockets created so far are closed before the error propagates.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                    self._route.host_for_client,
                    self._route.port_for_client,
                    type=socket.SOCK_STREAM,
            ):
                sock = Socket(addr[0], addr[1])
                res.append(sock)
                self._make_socket_nonblocking(sock)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(addr[4])
                sock.listen()
                logger.debug(
                    'Accepting connections on %s, fd=%s',
                    sock.getsockname(),
                    sock.fileno(),
                )
        except BaseException:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect to (host_to_server, port_to_server), trying every resolved
        address in turn.

        @returns the connected non-blocking socket, or None if no address
                 accepted the connection
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_STREAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                self._make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except asyncio.CancelledError:
                # Do not leak the half-connected socket when the accepting
                # task is cancelled in the middle of connecting.
                server.close()
                raise
            except Exception as exc:  # pylint: disable=broad-except
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket) -> None:
        """Accept clients forever, pairing each one with a server socket."""
        while accept_sock:
            client, _ = await self._loop.sock_accept(accept_sock)
            try:
                self._make_socket_nonblocking(client)
                server = await self._connect_to_server()
            except BaseException:
                # Cancellation (stop_accepting) or a resolver failure must
                # not leak the already accepted client socket.
                client.close()
                raise

            if server:
                self._sockets.add(
                    _SocketsPaired(
                        self._route.name,
                        self._loop,
                        client,
                        server,
                        self._to_server_intercept,
                        self._to_client_intercept,
                    ),
                )
                self._connected_event.set()
            else:
                client.close()

            self._collect_garbage()
846
847
class UdpGate(BaseGate):
    """
    Implements UDP chaos-proxy logic such as waiting for first
    message and setting up sockets for forwarding messages between
    udp-client and udp-server

    @ingroup userver_testsuite

    @see @ref scripts/docs/en/userver/chaos_testing.md
    """

    _NOT_IMPLEMENTED_IN_UDP = 'This method is not allowed in UDP gate'

    def __init__(self, route: GateRoute, loop: EvLoop):
        # Address of the first client that sent a message; set by _do_accept.
        self._client_addr: typing.Optional[Address] = None
        BaseGate.__init__(self, route, loop)

    def is_connected(self) -> bool:
        """
        Returns True if there is active pair of sockets ready to transfer data
        at the moment.
        """
        return len(self._sockets) > 0

    def _make_socket_nonblocking(self, sock: Socket) -> None:
        """Switch `sock` to non-blocking mode."""
        sock.setblocking(False)
        fcntl.fcntl(sock, fcntl.F_SETFL, os.O_NONBLOCK)

    def _create_accepting_sockets(self) -> typing.List[Socket]:
        """
        Create one bound datagram socket per resolved client address.

        If any socket fails to be created or bound, the sockets created so
        far are closed before the error propagates.
        """
        res: typing.List[Socket] = []
        try:
            for addr in socket.getaddrinfo(
                    self._route.host_for_client,
                    self._route.port_for_client,
                    type=socket.SOCK_DGRAM,
            ):
                sock = Socket(addr[0], addr[1])
                res.append(sock)
                self._make_socket_nonblocking(sock)
                sock.bind(addr[4])
                logger.debug('Accepting connections on %s', sock.getsockname())
        except BaseException:
            for sock in res:
                sock.close()
            raise

        return res

    async def _connect_to_server(self) -> typing.Optional[Socket]:
        """
        Connect a datagram socket to (host_to_server, port_to_server),
        trying every resolved address in turn.

        @returns the connected non-blocking socket, or None if connecting to
                 every address failed
        """
        addrs = await self._loop.getaddrinfo(
            self._route.host_to_server,
            self._route.port_to_server,
            type=socket.SOCK_DGRAM,
        )
        for addr in addrs:
            server = Socket(addr[0], addr[1])
            try:
                self._make_socket_nonblocking(server)
                await self._loop.sock_connect(server, addr[4])
            except asyncio.CancelledError:
                server.close()
                raise
            except Exception as exc:  # pylint: disable=broad-except
                # Close the failed socket instead of leaking it.
                server.close()
                logger.warning('Could not connect to %s: %s', addr[4], exc)
                continue
            logger.debug('Connected to %s', addr[4])
            return server
        return None

    async def _do_accept(self, accept_sock: Socket):
        """
        Wait for the first datagram, connect `accept_sock` back to its
        sender and pair it with a socket connected to the server.
        """
        if not accept_sock:
            return

        _, addr = await _wait_for_message_task(accept_sock)

        self._client_addr = addr
        try:
            await self._loop.sock_connect(accept_sock, addr)
        except Exception as exc:  # pylint: disable=broad-except
            logger.warning('Could not connect to %s: %s', addr, exc)

        server = await self._connect_to_server()
        if server:
            self._sockets.add(
                _SocketsPaired(
                    self._route.name,
                    self._loop,
                    accept_sock,
                    server,
                    self._to_server_intercept,
                    self._to_client_intercept,
                ),
            )
            self._connected_event.set()
        else:
            accept_sock.close()

        self._collect_garbage()

    def start_accepting(self) -> None:
        raise NotImplementedError(
            'Since UdpGate can only have one connection, you cannot start or '
            'stop accepting tasks manually. Use start() and stop() methods to '
            'stop data transferring',
        )

    async def stop_accepting(self) -> None:
        raise NotImplementedError(
            'Since UdpGate can only have one connection, you cannot start or '
            'stop accepting tasks manually. Use start() and stop() methods to '
            'stop data transferring',
        )

    def to_server_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_client_concat_packets(self, packet_size: int) -> None:
        raise NotImplementedError('Udp packets cannot be concatenated')

    def to_server_smaller_parts(
            self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')

    def to_client_smaller_parts(
            self, max_size: int, *, sleep_per_packet: float = 0,
    ) -> None:
        raise NotImplementedError('Udp packets cannot be split')