26 Class that describes the route for TcpGate or UdpGate.
28 Use `port_for_client == 0` to bind to some unused port. In that case the
29 actual address could be retrieved via BaseGate.get_sockname_for_clients().
31 @ingroup userver_testsuite
37 host_for_client: str =
'127.0.0.1'
38 port_for_client: int = 0
48logger = logging.getLogger(__name__)
51Address = typing.Tuple[str, int]
54Interceptor = typing.Callable[
55 [EvLoop, Socket, Socket], typing.Coroutine[typing.Any, typing.Any,
None],
59class GateException(Exception):
63class GateInterceptException(Exception):
67async def _yield() -> None:
72 await asyncio.sleep(min_delay)
77) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
80 return recv_socket.recvfrom(RECV_MAX_SIZE, socket.MSG_PEEK)
81 except socket.error
as exc:
83 if err
in {errno.EAGAIN, errno.EWOULDBLOCK}:
88async def _wait_for_message_task(
90) -> typing.Tuple[bytes, Address]:
92 msg, addr = _try_get_message(recv_socket)
100def _incoming_data_size(recv_socket: Socket) -> int:
101 msg, _ = _try_get_message(recv_socket)
102 return len(msg)
if msg
else 0
105async def _intercept_ok(
106 loop: EvLoop, socket_from: Socket, socket_to: Socket,
108 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
109 await loop.sock_sendall(socket_to, data)
112async def _intercept_noop(
113 loop: EvLoop, socket_from: Socket, socket_to: Socket,
118async def _intercept_drop(
119 loop: EvLoop, socket_from: Socket, socket_to: Socket,
121 await loop.sock_recv(socket_from, RECV_MAX_SIZE)
124async def _intercept_delay(
125 delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
127 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
128 await asyncio.sleep(delay)
129 await loop.sock_sendall(socket_to, data)
132async def _intercept_close_on_data(
133 loop: EvLoop, socket_from: Socket, socket_to: Socket,
135 await loop.sock_recv(socket_from, 1)
136 raise GateInterceptException(
'Closing socket on data')
139async def _intercept_corrupt(
140 loop: EvLoop, socket_from: Socket, socket_to: Socket,
142 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
143 await loop.sock_sendall(socket_to, bytearray([
not x
for x
in data]))
146class _InterceptBpsLimit:
147 def __init__(self, bytes_per_second: float):
148 assert bytes_per_second >= 1
149 self._bytes_per_second = bytes_per_second
150 self._time_last_added = 0.0
151 self._bytes_left = self._bytes_per_second
153 def _update_limit(self) -> None:
154 current_time = time.monotonic()
155 elapsed = current_time - self._time_last_added
156 bytes_addition = self._bytes_per_second * elapsed
157 if bytes_addition > 0:
158 self._bytes_left += bytes_addition
159 self._time_last_added = current_time
161 if self._bytes_left > self._bytes_per_second:
162 self._bytes_left = self._bytes_per_second
165 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
169 bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
170 if bytes_to_recv > 0:
171 data = await loop.sock_recv(socket_from, bytes_to_recv)
172 self._bytes_left -= len(data)
174 await loop.sock_sendall(socket_to, data)
176 logger.info(
'Socket hits the bytes per second limit')
177 await asyncio.sleep(1.0 / self._bytes_per_second)
180class _InterceptTimeLimit:
181 def __init__(self, timeout: float, jitter: float):
182 self._sockets: typing.Dict[Socket, float] = {}
183 assert timeout >= 0.0
184 self._timeout = timeout
186 self._jitter = jitter
188 def raise_if_timed_out(self, socket_from: Socket) ->
None:
189 if socket_from
not in self._sockets:
190 jitter = self._jitter * random.random()
191 expire_at = time.monotonic() + self._timeout + jitter
192 self._sockets[socket_from] = expire_at
194 if self._sockets[socket_from] <= time.monotonic():
195 del self._sockets[socket_from]
196 raise GateInterceptException(
'Socket hits the time limit')
199 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
201 self.raise_if_timed_out(socket_from)
202 await _intercept_ok(loop, socket_from, socket_to)
205class _InterceptSmallerParts:
206 def __init__(self, max_size: int, sleep_per_packet: float):
208 self._max_size = max_size
209 self._sleep_per_packet = sleep_per_packet
212 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
214 incoming_size = _incoming_data_size(socket_from)
215 chunk_size = min(incoming_size, self._max_size)
216 data = await loop.sock_recv(socket_from, chunk_size)
217 await asyncio.sleep(self._sleep_per_packet)
218 await loop.sock_sendall(socket_to, data)
221class _InterceptConcatPackets:
222 def __init__(self, packet_size: int):
223 assert packet_size >= 0
224 self._packet_size = packet_size
225 self._expire_at: typing.Optional[float] =
None
228 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
230 if self._expire_at
is None:
231 self._expire_at = time.monotonic() + MAX_DELAY
233 if self._expire_at <= time.monotonic():
235 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
236 'seconds. Check the test logic, it should end with checking '
237 'that the data was sent and by calling TcpGate function '
238 'to_client_pass() to pass the remaining packets.',
242 incoming_size = _incoming_data_size(socket_from)
243 if incoming_size >= self._packet_size:
244 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
245 await loop.sock_sendall(socket_to, data)
246 self._expire_at =
None
249class _InterceptBytesLimit:
250 def __init__(self, bytes_limit: int, gate:
'BaseGate'):
251 assert bytes_limit >= 0
252 self._bytes_limit = bytes_limit
253 self._bytes_remain = self._bytes_limit
257 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
259 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
260 if self._bytes_remain <= len(data):
261 await loop.sock_sendall(socket_to, data[0 : self._bytes_remain])
262 await self._gate.sockets_close()
263 self._bytes_remain = self._bytes_limit
264 raise GateInterceptException(
'Data transmission limit reached')
266 self._bytes_remain -= len(data)
267 await loop.sock_sendall(socket_to, data)
270class _InterceptSubstitute:
271 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
272 self._pattern = re.compile(pattern)
274 self._encoding = encoding
277 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
279 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
281 res = self._pattern.sub(self._repl, data.decode(self._encoding))
282 data = res.encode(self._encoding)
285 await loop.sock_sendall(socket_to, data)
288async def _cancel_and_join(task: typing.Optional[asyncio.Task]) ->
None:
289 if not task
or task.cancelled():
295 except asyncio.CancelledError:
297 except Exception
as exc:
298 logger.error(
'Exception in _cancel_and_join: %s', exc)
306 client: socket.socket,
307 server: socket.socket,
308 to_server_intercept: Interceptor,
309 to_client_intercept: Interceptor,
311 self._proxy_name = proxy_name
314 self._client = client
315 self._server = server
317 self._to_server_intercept: Interceptor = to_server_intercept
318 self._to_client_intercept: Interceptor = to_client_intercept
320 self._task_to_server = asyncio.create_task(
321 self._do_pipe_channels(to_server=
True),
323 self._task_to_client = asyncio.create_task(
324 self._do_pipe_channels(to_server=
False),
327 self._finished_channels = 0
329 async def _do_pipe_channels(self, *, to_server: bool) ->
None:
331 socket_from = self._client
332 socket_to = self._server
334 socket_from = self._server
335 socket_to = self._client
344 if not _incoming_data_size(socket_from):
349 interceptor = self._to_server_intercept
351 interceptor = self._to_client_intercept
353 await interceptor(self._loop, socket_from, socket_to)
355 except GateInterceptException
as exc:
356 logger.info(
'In "%s": %s', self._proxy_name, exc)
357 except socket.error
as exc:
358 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
360 self._finished_channels += 1
361 if self._finished_channels == 2:
364 logger.info(
'"%s" closes %s', self._proxy_name, self.info())
365 self._close_socket(self._client)
366 self._close_socket(self._server)
368 assert self._finished_channels == 1
370 self._task_to_client.cancel()
372 self._task_to_server.cancel()
374 def set_to_server_interceptor(self, interceptor: Interceptor) ->
None:
375 self._to_server_intercept = interceptor
377 def set_to_client_interceptor(self, interceptor: Interceptor) ->
None:
378 self._to_client_intercept = interceptor
380 def _close_socket(self, self_socket: Socket) ->
None:
381 assert self_socket
in {self._client, self._server}
384 except socket.error
as exc:
386 'Exception in "%s" on closing %s: %s',
388 'client' if self_socket == self._client
else 'server',
392 async def shutdown(self) -> None:
393 for task
in {self._task_to_client, self._task_to_server}:
394 await _cancel_and_join(task)
396 def is_active(self) -> bool:
398 not self._task_to_client.done()
or not self._task_to_server.done()
401 def info(self) -> str:
402 if not self.is_active():
406 f
'client fd={self._client.fileno()} <=> '
407 f
'server fd={self._server.fileno()}'