testsuite/pytest_plugins/pytest_userver/plugins/s3api.py
import collections
import dataclasses
import datetime as dt
import hashlib
import pathlib
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Union

import dateutil.tz as tz
import pytest

pytest_plugins = ['pytest_userver.plugins.core']


def pytest_configure(config):
    config.addinivalue_line('markers', 's3: store s3 files in mock')
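

# Example usage of the `s3` marker registered above (a sketch; the bucket and
# file names are hypothetical). The `s3_apply` fixture at the bottom of this
# file loads each file via the `load` fixture and stores it in the mock
# bucket before the test runs:
#
#     @pytest.mark.s3('my-bucket', {'path/in/bucket.json': 'local_file.json'})
#     async def test_something(s3_mock_storage):
#         ...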


@dataclasses.dataclass
class S3Object:
    data: bytes
    meta: Mapping[str, str]


class S3MockBucketStorage:
    def __init__(self):
        # Use Path to normalize keys (e.g. /a//file.json)
        self._storage: Dict[pathlib.Path, S3Object] = {}

    @staticmethod
    def _generate_etag(data):
        return hashlib.md5(data).hexdigest()

    def put_object(
        self,
        key: str,
        data: bytes,
        last_modified: Optional[Union[dt.datetime, str]] = None,
    ):
        key_path = pathlib.Path(key)
        if last_modified is None:
            # Timezone is needed for the RFC 3339 time format used by S3
            last_modified = (
                dt.datetime.now().replace(tzinfo=tz.tzlocal()).isoformat()
            )
        elif isinstance(last_modified, dt.datetime):
            last_modified = last_modified.isoformat()

        meta = {
            'Key': str(key_path),
            'ETag': self._generate_etag(data),
            'Last-Modified': last_modified,
            # Size is the content length in bytes, as S3 reports it
            'Size': str(len(data)),
        }
        self._storage[key_path] = S3Object(data, meta)
        return meta
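
    # Example (sketch) of the meta returned by put_object: the key is
    # normalized, the ETag is an md5 hex digest, and all values are strings:
    #     put_object('a//file.json', b'data') -> {
    #         'Key': 'a/file.json', 'ETag': '...', 'Last-Modified': '...',
    #         'Size': '4',
    #     }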

    def get_object(self, key: str) -> Optional[S3Object]:
        key_path = pathlib.Path(key)
        return self._storage.get(key_path)

    def get_objects(self, parent_dir='') -> Dict[str, S3Object]:
        all_objects = {
            str(key_path): value for key_path, value in self._storage.items()
        }

        if not parent_dir:
            return all_objects

        # Plain string-prefix match on normalized keys
        return {
            key: value
            for key, value in all_objects.items()
            if key.startswith(str(pathlib.Path(parent_dir)))
        }

    def delete_object(self, key: str) -> Optional[S3Object]:
        key_path = pathlib.Path(key)
        if key_path not in self._storage:
            return None
        return self._storage.pop(key_path)
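

# Example (sketch; the key names are made up) of how get_objects filters by a
# plain string prefix on normalized keys:
#
#     storage = S3MockBucketStorage()
#     storage.put_object('dir/a.json', b'{}')
#     storage.put_object('other/b.json', b'{}')
#     assert list(storage.get_objects(parent_dir='dir')) == ['dir/a.json']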


class S3HandleMock:
    def __init__(self, mockserver, s3_mock_storage, mock_base_url):
        self._mockserver = mockserver
        self._base_url = mock_base_url
        self._storage = s3_mock_storage

    def _get_bucket_name(self, request):
        # Virtual-hosted style: the bucket is the first label of the
        # Host header, e.g. 'my-bucket.s3.example.com' -> 'my-bucket'
        return request.headers['Host'].split('.')[0]

    def _extract_key(self, request):
        # Strip the mock base url and the leading slash from the path
        return request.path[len(self._base_url) + 1 :]

    def _generate_get_objects_result(
        self,
        s3_objects_dict: Dict[str, S3Object],
        max_keys: int,
        marker: Optional[str],
    ):
        empty_result = {'result_objects': [], 'is_truncated': False}
        # S3 lists keys in lexicographic order; the marker logic relies on it
        keys = sorted(s3_objects_dict.keys())
        if not keys:
            return empty_result

        from_index = 0
        if marker:
            if marker >= keys[-1]:
                # The marker is at or past the last key, nothing is left
                return empty_result
            for i, key in enumerate(keys):
                if key > marker:
                    from_index = i
                    break

        result_objects = [
            s3_objects_dict[key]
            for key in keys[from_index : from_index + max_keys]
        ]
        # Truncated means there are more keys past this page
        is_truncated = from_index + max_keys < len(keys)
        return {'result_objects': result_objects, 'is_truncated': is_truncated}
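
    # Worked example (sketch): for sorted keys ['a', 'b', 'c'] and max_keys=2,
    # marker=None yields the page ['a', 'b'] with is_truncated=True ('c' is
    # left over), while marker='a' yields ['b', 'c'] with is_truncated=False.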

    def _generate_get_objects_xml(
        self,
        s3_objects: List[S3Object],
        bucket_name: str,
        prefix: str,
        max_keys: Optional[int],
        marker: Optional[str],
        is_truncated: bool,
    ):
        contents = ''
        for s3_object in s3_objects:
            contents += f"""
            <Contents>
                <ETag>{s3_object.meta['ETag']}</ETag>
                <Key>{s3_object.meta['Key']}</Key>
                <LastModified>{s3_object.meta['Last-Modified']}</LastModified>
                <Size>{s3_object.meta['Size']}</Size>
                <StorageClass>STANDARD</StorageClass>
            </Contents>
            """
        # S3 serializes booleans in lowercase
        return f"""
        <?xml version="1.0" encoding="UTF-8"?>
        <ListBucketResult>
            <Name>{bucket_name}</Name>
            <Prefix>{prefix}</Prefix>
            <Marker>{marker or ''}</Marker>
            <MaxKeys>{max_keys or ''}</MaxKeys>
            <IsTruncated>{str(is_truncated).lower()}</IsTruncated>
            {contents}
        </ListBucketResult>
        """

    def get_object(self, request):
        key = self._extract_key(request)

        bucket_storage = self._storage[self._get_bucket_name(request)]

        s3_object = bucket_storage.get_object(key)
        if not s3_object:
            return self._mockserver.make_response('Object not found', 404)
        return self._mockserver.make_response(
            s3_object.data, 200, headers=s3_object.meta,
        )

    def put_object(self, request):
        key = self._extract_key(request)

        bucket_storage = self._storage[self._get_bucket_name(request)]

        data = request.get_data()
        meta = bucket_storage.put_object(key, data)
        return self._mockserver.make_response('OK', 200, headers=meta)

    def copy_object(self, request):
        key = self._extract_key(request)
        dest_bucket_name = self._get_bucket_name(request)
        # 'x-amz-copy-source' looks like '/source-bucket/source/key'
        source_bucket_name, source_key = request.headers.get(
            'x-amz-copy-source',
        ).split('/', 2)[1:3]

        src_bucket_storage = self._storage[source_bucket_name]
        dst_bucket_storage = self._storage[dest_bucket_name]

        source_object = src_bucket_storage.get_object(source_key)
        if not source_object:
            return self._mockserver.make_response('Object not found', 404)
        meta = dst_bucket_storage.put_object(key, source_object.data)
        return self._mockserver.make_response('OK', 200, headers=meta)

    def get_objects(self, request):
        prefix = request.query['prefix']
        # 1000 is the default value specified by the AWS spec
        max_keys = int(request.query.get('max-keys', 1000))
        marker = request.query.get('marker')

        bucket_name = self._get_bucket_name(request)
        bucket_storage = self._storage[bucket_name]

        s3_objects_dict = bucket_storage.get_objects(parent_dir=prefix)
        result = self._generate_get_objects_result(
            s3_objects_dict=s3_objects_dict, max_keys=max_keys, marker=marker,
        )
        result_xml = self._generate_get_objects_xml(
            s3_objects=result['result_objects'],
            bucket_name=bucket_name,
            prefix=prefix,
            max_keys=max_keys,
            marker=marker,
            is_truncated=result['is_truncated'],
        )
        return self._mockserver.make_response(result_xml, 200)

    def delete_object(self, request):
        key = self._extract_key(request)

        bucket_storage = self._storage[self._get_bucket_name(request)]

        bucket_storage.delete_object(key)
        # S3 always returns 204, even if the file doesn't exist
        return self._mockserver.make_response('OK', 204)

    def get_object_head(self, request):
        key = self._extract_key(request)

        bucket_storage = self._storage[self._get_bucket_name(request)]

        s3_object = bucket_storage.get_object(key)
        if not s3_object:
            return self._mockserver.make_response('Object not found', 404)
        return self._mockserver.make_response(
            'OK', 200, headers=s3_object.meta,
        )


@pytest.fixture(name='s3_mock_storage')
def _s3_mock_storage():
    buckets = collections.defaultdict(S3MockBucketStorage)
    return buckets
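

# How the catch-all mock below routes requests (a sketch; the host name is
# hypothetical): the bucket comes from the Host header and the key from the
# request path, so a GET with 'Host: my-bucket.s3.example.com' and path
# '/mds-s3/some/key.json' reads key 'some/key.json' from bucket 'my-bucket',
# while a GET carrying a 'prefix' query parameter lists objects instead.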


@pytest.fixture(autouse=True)
def s3_mock(mockserver, s3_mock_storage):
    mock_base_url = '/mds-s3'
    mock_impl = S3HandleMock(
        mockserver=mockserver,
        s3_mock_storage=s3_mock_storage,
        mock_base_url=mock_base_url,
    )

    @mockserver.handler(mock_base_url, prefix=True)
    def _mock_all(request):
        if request.method == 'GET':
            if 'prefix' in request.query:
                return mock_impl.get_objects(request)
            return mock_impl.get_object(request)

        if request.method == 'PUT':
            if request.headers.get('x-amz-copy-source', None):
                return mock_impl.copy_object(request)
            return mock_impl.put_object(request)

        if request.method == 'DELETE':
            return mock_impl.delete_object(request)

        if request.method == 'HEAD':
            return mock_impl.get_object_head(request)

        return mockserver.make_response('Unknown or unsupported method', 404)


@pytest.fixture(autouse=True)
def s3_apply(request, s3_mock_storage, load):
    def _put_files(bucket, files):
        bucket_storage = s3_mock_storage[bucket]
        for s3_path, file_path in files.items():
            bucket_storage.put_object(
                key=s3_path, data=load(file_path).encode('utf-8'),
            )

    for mark in request.node.iter_markers('s3'):
        _put_files(*mark.args, **mark.kwargs)
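

# Example test (sketch; the bucket name and keys are made up). Buckets are
# created on first access, so a test can seed and inspect the mock storage
# directly through the fixture, without going over HTTP:
#
#     def test_direct_storage(s3_mock_storage):
#         bucket = s3_mock_storage['my-bucket']
#         bucket.put_object('some/key.txt', b'payload')
#         assert bucket.get_object('some/key.txt').data == b'payload'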