Class that describes the route for TcpGate or UdpGate.

Use `port_for_client == 0` to bind to some unused port. In that case the
actual address can be retrieved via BaseGate.get_sockname_for_clients().

@ingroup userver_testsuite
37 host_for_client: str =
'127.0.0.1'
38 port_for_client: int = 0
# Module-level logger used by the interceptors and helpers of this module.
logger = logging.getLogger(__name__)

# Network endpoint as a (host, port) pair.
Address = typing.Tuple[str, int]

# An interceptor is an async callable invoked as
# `interceptor(loop, socket_from, socket_to)`; it decides how the data pending
# on `socket_from` is relayed to `socket_to`.
# NOTE(review): EvLoop and Socket are expected to be defined just above this
# alias; their definitions were not visible when this block was reviewed.
Interceptor = typing.Callable[
    [EvLoop, Socket, Socket],
    typing.Coroutine[typing.Any, typing.Any, None],
]
class GateException(Exception):
    """
    Error type for gate failures.

    NOTE(review): no raise site was visible when this block was reviewed.
    """
class GateInterceptException(Exception):
    """
    Raised by an interceptor to abort relaying on a channel.

    Interceptors raise it, for example, when a time or byte limit is hit.
    """
async def _yield() -> None:
    """Suspend the calling coroutine for `_MIN_DELAY` seconds."""
    await asyncio.sleep(_MIN_DELAY)
# Peeks at the data pending on `recv_socket` without consuming it
# (`recvfrom` with MSG_PEEK, at most RECV_MAX_SIZE bytes) and returns the
# (message, sender address) pair; both items are Optional.
# A socket.error is caught and `err` is compared against EAGAIN/EWOULDBLOCK,
# presumably to treat "nothing to read yet" differently from real failures.
# NOTE(review): the `def` header, the `try:` line, the assignment of `err`
# and the bodies of both branches are missing from this excerpt; restore them
# from the original file before changing this function.
77) -> typing.Tuple[typing.Optional[bytes], typing.Optional[Address]]:
80 return recv_socket.recvfrom(RECV_MAX_SIZE, socket.MSG_PEEK)
81 except socket.error
as e:
83 if err
in {errno.EAGAIN, errno.EWOULDBLOCK}:
# Coroutine that produces a (message, address) pair for `recv_socket`,
# obtained through _try_get_message().
# NOTE(review): the parameter line and everything after the peek (how the
# result is checked and how the coroutine waits for data) are missing from
# this excerpt; restore them from the original file before editing.
88async def _wait_for_message_task(
90) -> typing.Tuple[bytes, Address]:
92 msg, addr = _try_get_message(recv_socket)
def _incoming_data_size(recv_socket: Socket) -> int:
    """
    Return the size of the message currently pending on `recv_socket`.

    The message is obtained through _try_get_message(); 0 is returned when
    that helper yields no message (or an empty one).
    """
    msg, _ = _try_get_message(recv_socket)
    return len(msg) if msg else 0
async def _intercept_ok(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Relay data unchanged.

    Receives up to RECV_MAX_SIZE bytes from `socket_from` and sends all of
    them to `socket_to`.
    """
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await loop.sock_sendall(socket_to, data)
async def _intercept_noop(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Do nothing: data is neither read from `socket_from` nor forwarded.

    NOTE(review): the single body line was not visible when this block was
    reviewed; an empty body is assumed from the function name.
    """
async def _intercept_delay(
    delay: float, loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Relay data with an artificial pause.

    Receives up to RECV_MAX_SIZE bytes from `socket_from`, sleeps for `delay`
    seconds and only then sends the data to `socket_to`.

    The extra leading `delay` parameter means this function has to be bound
    to a concrete delay before it matches the three-argument Interceptor
    signature.
    """
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    await asyncio.sleep(delay)
    await loop.sock_sendall(socket_to, data)
async def _intercept_close_on_data(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Abort the channel as soon as any data arrives.

    Waits for a single byte on `socket_from` and then always raises
    GateInterceptException; nothing is forwarded to `socket_to`.
    """
    await loop.sock_recv(socket_from, 1)
    raise GateInterceptException('Closing socket on data')
async def _intercept_corrupt(
    loop: EvLoop, socket_from: Socket, socket_to: Socket,
) -> None:
    """
    Relay data after corrupting every byte.

    Receives up to RECV_MAX_SIZE bytes from `socket_from` and sends a buffer
    of the same length to `socket_to` in which no byte keeps its value.
    """
    data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
    # `not x` is a bool: a zero byte becomes 0x01 and any non-zero byte
    # becomes 0x00, so each output byte differs from its input byte.
    corrupted = bytearray([not x for x in data])
    await loop.sock_sendall(socket_to, corrupted)
class _InterceptBpsLimit:
    """
    Interceptor that throttles a channel to `bytes_per_second`.

    A byte budget is refilled in proportion to the elapsed monotonic time and
    is capped at one second's worth of traffic. Each call relays at most the
    current budget; when the budget is empty the call only sleeps.
    """

    def __init__(self, bytes_per_second: float):
        assert bytes_per_second >= 1
        self._bytes_per_second = bytes_per_second
        # Zero makes the first refresh see a large elapsed time; the cap in
        # _update_limit() keeps the budget at `bytes_per_second` anyway.
        self._time_last_added = 0.0
        self._bytes_left = self._bytes_per_second

    def _update_limit(self) -> None:
        """Refill the byte budget according to the time since the last refill."""
        current_time = time.monotonic()
        elapsed = current_time - self._time_last_added
        bytes_addition = self._bytes_per_second * elapsed
        if bytes_addition > 0:
            self._bytes_left += bytes_addition
            self._time_last_added = current_time

            # Never accumulate more than one second's worth of budget.
            if self._bytes_left > self._bytes_per_second:
                self._bytes_left = self._bytes_per_second

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        """
        Relay as many bytes as the current budget allows.

        NOTE(review): the method header, the budget refresh call and the
        `else:` line were not visible when this block was reviewed; they are
        reconstructed from the surrounding code. Confirm against the original.
        """
        self._update_limit()

        bytes_to_recv = min(int(self._bytes_left), RECV_MAX_SIZE)
        if bytes_to_recv > 0:
            data = await loop.sock_recv(socket_from, bytes_to_recv)
            self._bytes_left -= len(data)

            await loop.sock_sendall(socket_to, data)
        else:
            logger.info('Socket hits the bytes per second limit')
            # Sleep for the time it takes to earn one byte of budget.
            await asyncio.sleep(1.0 / self._bytes_per_second)
class _InterceptTimeLimit:
    """
    Interceptor that relays data until a per-socket deadline expires.

    The deadline of a source socket is fixed on the first call that sees it:
    now + `timeout` + a random share of `jitter`. Once the deadline has
    passed, GateInterceptException is raised instead of relaying.
    """

    def __init__(self, timeout: float, jitter: float):
        # Deadline (monotonic clock) per source socket.
        self._sockets: typing.Dict[Socket, float] = {}
        assert timeout >= 0.0
        self._timeout = timeout
        # NOTE(review): one line was missing here when this block was
        # reviewed; a check mirroring the `timeout` one is assumed.
        assert jitter >= 0.0
        self._jitter = jitter

    def raise_if_timed_out(self, socket_from: Socket) -> None:
        """
        Register `socket_from` if it is new and enforce its deadline.

        Raises GateInterceptException when the deadline has passed. The
        socket is forgotten in that case, so a later call starts a new
        deadline for it.
        """
        if socket_from not in self._sockets:
            jitter = self._jitter * random.random()
            expire_at = time.monotonic() + self._timeout + jitter
            self._sockets[socket_from] = expire_at

        if self._sockets[socket_from] <= time.monotonic():
            del self._sockets[socket_from]
            raise GateInterceptException('Socket hits the time limit')

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        """
        Relay data unchanged unless the socket's deadline has expired.

        NOTE(review): the method header was not visible when this block was
        reviewed; it is reconstructed from the Interceptor signature.
        """
        self.raise_if_timed_out(socket_from)
        await _intercept_ok(loop, socket_from, socket_to)
class _InterceptSmallerParts:
    """
    Interceptor that relays pending data in chunks of at most `max_size`
    bytes, sleeping `sleep_per_packet` seconds before sending each chunk.
    """

    def __init__(self, max_size: int, sleep_per_packet: float):
        # NOTE(review): one line was missing here when this block was
        # reviewed. A positivity check is assumed, because a zero-sized chunk
        # would never drain the socket. Confirm against the original file.
        assert max_size > 0
        self._max_size = max_size
        self._sleep_per_packet = sleep_per_packet

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        """
        Relay one chunk from `socket_from` to `socket_to`.

        The chunk size is the smaller of the pending data size and
        `max_size`.

        NOTE(review): the method header was not visible when this block was
        reviewed; it is reconstructed from the Interceptor signature.
        """
        incoming_size = _incoming_data_size(socket_from)
        chunk_size = min(incoming_size, self._max_size)
        data = await loop.sock_recv(socket_from, chunk_size)
        await asyncio.sleep(self._sleep_per_packet)
        await loop.sock_sendall(socket_to, data)
# Interceptor that holds data back until at least `packet_size` bytes are
# pending on `socket_from`, then relays up to RECV_MAX_SIZE bytes at once and
# clears the deadline. The deadline is set to now + MAX_DELAY on the first
# call after a reset; when it expires, the failure message below is reported.
# NOTE(review): the `__call__` header and the call that wraps the failure
# message (fail/raise) are missing from this excerpt; restore them from the
# original file before editing.
215class _InterceptConcatPackets:
216 def __init__(self, packet_size: int):
217 assert packet_size >= 0
218 self._packet_size = packet_size
219 self._expire_at: typing.Optional[float] =
None
222 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
224 if self._expire_at
is None:
225 self._expire_at = time.monotonic() + MAX_DELAY
227 if self._expire_at <= time.monotonic():
229 f
'Failed to make a packet of sufficient size in {MAX_DELAY} '
230 'seconds. Check the test logic, it should end with checking '
231 'that the data was sent and by calling TcpGate function '
232 'to_client_pass() to pass the remaining packets.',
236 incoming_size = _incoming_data_size(socket_from)
237 if incoming_size >= self._packet_size:
238 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
239 await loop.sock_sendall(socket_to, data)
240 self._expire_at =
None
class _InterceptBytesLimit:
    """
    Interceptor that lets at most `bytes_limit` bytes through and then closes
    the gate's sockets.
    """

    def __init__(self, bytes_limit: int, gate: 'BaseGate'):
        assert bytes_limit >= 0
        self._bytes_limit = bytes_limit
        self._bytes_remain = self._bytes_limit
        # NOTE(review): this assignment was not visible when this block was
        # reviewed; it is implied by the use of `self._gate` in __call__.
        self._gate = gate

    async def __call__(
        self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
    ) -> None:
        """
        Relay data while the byte budget lasts.

        When the received chunk reaches or exceeds the remaining budget, only
        the remaining prefix is sent, the gate's sockets are closed, the
        budget is reset to `bytes_limit` and GateInterceptException is
        raised. Otherwise the chunk is forwarded and deducted from the budget.

        NOTE(review): the method header was not visible when this block was
        reviewed; it is reconstructed from the Interceptor signature.
        """
        data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
        if self._bytes_remain <= len(data):
            await loop.sock_sendall(socket_to, data[: self._bytes_remain])
            await self._gate.sockets_close()
            self._bytes_remain = self._bytes_limit
            raise GateInterceptException('Data transmission limit reached')

        self._bytes_remain -= len(data)
        await loop.sock_sendall(socket_to, data)
# Interceptor that rewrites relayed data: up to RECV_MAX_SIZE bytes are
# received, decoded with `encoding`, passed through the compiled `pattern`
# with the stored replacement, re-encoded and sent to `socket_to`.
# NOTE(review): the assignment of `self._repl`, the `__call__` header and the
# lines surrounding the decode/substitute step (presumably error handling for
# undecodable data) are missing from this excerpt; restore them from the
# original file before editing.
264class _InterceptSubstitute:
265 def __init__(self, pattern: str, repl: str, encoding=
'utf-8'):
266 self._pattern = re.compile(pattern)
268 self._encoding = encoding
271 self, loop: EvLoop, socket_from: Socket, socket_to: Socket,
273 data = await loop.sock_recv(socket_from, RECV_MAX_SIZE)
275 res = self._pattern.sub(self._repl, data.decode(self._encoding))
276 data = res.encode(self._encoding)
279 await loop.sock_sendall(socket_to, data)
async def _cancel_and_join(task: typing.Optional[asyncio.Task]) -> None:
    """
    Cancel `task` and wait until it has actually finished.

    Does nothing for a missing or already cancelled task. The resulting
    CancelledError is swallowed; any other exception raised by the task is
    logged and not propagated.

    NOTE(review): the early `return`, the `try:` line and the cancel/await
    statements were not visible when this block was reviewed; they are
    reconstructed from the function name and the visible `except` clauses.
    """
    if not task or task.cancelled():
        return

    try:
        task.cancel()
        await task
    except asyncio.CancelledError:
        return
    except Exception as exc:
        logger.error('Exception in _cancel_and_join: %s', exc)
# Constructor fragment of the socket-pair proxy: stores the proxy name, the
# client and server sockets and the interceptor for each direction, starts
# one _do_pipe_channels() task per direction and zeroes the counter of
# finished channels.
# NOTE(review): the class header, the beginning of the `def __init__(`
# signature and several body lines (e.g. where `self._loop` is assigned and
# where the create_task calls are closed) are missing from this excerpt;
# restore them from the original file before editing.
300 client: socket.socket,
301 server: socket.socket,
302 to_server_intercept: Interceptor,
303 to_client_intercept: Interceptor,
305 self._proxy_name = proxy_name
308 self._client = client
309 self._server = server
311 self._to_server_intercept: Interceptor = to_server_intercept
312 self._to_client_intercept: Interceptor = to_client_intercept
314 self._task_to_server = asyncio.create_task(
315 self._do_pipe_channels(to_server=
True),
317 self._task_to_client = asyncio.create_task(
318 self._do_pipe_channels(to_server=
False),
321 self._finished_channels = 0
# Relay routine for one direction of the pair. It selects the source and
# destination sockets, checks for pending data with _incoming_data_size() and
# awaits the interceptor currently configured for the direction.
# GateInterceptException is logged at info level, socket.error at error level.
# On exit the finished-channel counter is incremented: when it reaches 2 both
# sockets are closed, otherwise it must be 1 and a sibling relay task is
# cancelled.
# NOTE(review): the `if to_server:`/`else:` lines, the loop and `try`/`finally`
# structure, and whatever happens when no data is pending are missing from
# this excerpt; restore them from the original file before editing.
323 async def _do_pipe_channels(self, *, to_server: bool) ->
None:
325 socket_from = self._client
326 socket_to = self._server
328 socket_from = self._server
329 socket_to = self._client
338 if not _incoming_data_size(socket_from):
343 interceptor = self._to_server_intercept
345 interceptor = self._to_client_intercept
347 await interceptor(self._loop, socket_from, socket_to)
349 except GateInterceptException
as exc:
350 logger.info(
'In "%s": %s', self._proxy_name, exc)
351 except socket.error
as exc:
352 logger.error(
'Exception in "%s": %s', self._proxy_name, exc)
354 self._finished_channels += 1
355 if self._finished_channels == 2:
358 logger.info(
'"%s" closes %s', self._proxy_name, self.info())
359 self._close_socket(self._client)
360 self._close_socket(self._server)
362 assert self._finished_channels == 1
364 self._task_to_client.cancel()
366 self._task_to_server.cancel()
def set_to_server_interceptor(self, interceptor: Interceptor) -> None:
    """Replace the interceptor stored for the to-server direction."""
    self._to_server_intercept = interceptor
def set_to_client_interceptor(self, interceptor: Interceptor) -> None:
    """Replace the interceptor stored for the to-client direction."""
    self._to_client_intercept = interceptor
# Handles one of the two sockets owned by this pair: asserts that
# `self_socket` is either the client or the server socket and reports a
# socket.error with a message naming the affected side ('client'/'server').
# Presumably the missing `try` body closes the socket.
# NOTE(review): the `try:` line with its body and the logging call that uses
# the message below are missing from this excerpt; restore them from the
# original file before editing.
374 def _close_socket(self, self_socket: Socket) ->
None:
375 assert self_socket
in {self._client, self._server}
378 except socket.error
as exc:
380 'Exception in "%s" on closing %s: %s',
382 'client' if self_socket == self._client
else 'server',
async def shutdown(self) -> None:
    """Cancel both relay tasks and wait for each of them to finish."""
    # A tuple keeps the order deterministic: to-client first, then to-server.
    for task in (self._task_to_client, self._task_to_server):
        await _cancel_and_join(task)
def is_active(self) -> bool:
    """Return True while at least one of the two relay tasks is not done."""
    return (
        not self._task_to_client.done() or not self._task_to_server.done()
    )
# Builds a human-readable description of the pair from the file descriptors
# of the client and server sockets; an inactive pair takes a separate branch.
# NOTE(review): the statement executed for the inactive case and the `return`
# that wraps the two f-strings are missing from this excerpt; restore them
# from the original file before editing.
395 def info(self) -> str:
396 if not self.is_active():
400 f
'client fd={self._client.fileno()} <=> '
401 f
'server fd={self._server.fileno()}'