userver: /data/code/userver/testsuite/pytest_plugins/pytest_userver/plugins/ydb/client.py Source File
Loading...
Searching...
No Matches
client.py
1import ydb as ydb_native_client
2
3
def __init__(self, endpoint, database):
    """Connect to YDB and open a table session.

    Args:
        endpoint: YDB endpoint, forwarded to ``_init_driver``.
        database: database name, forwarded to ``_init_driver`` and kept for
            building table paths (exposed via the ``database`` property).
    """
    driver = self._init_driver(endpoint, database)
    self._driver = driver
    self._database = database
    # A single session is created up front and reused by every call.
    self._session = driver.table_client.session().create()
9
def execute(self, query):
    """Run *query* in a fresh transaction on the client's session and commit.

    Returns whatever the transaction's ``execute`` call returns.
    """
    transaction = self._session.transaction()
    return transaction.execute(query, commit_tx=True)
12
@property
def session(self):
    """Table session created when the client was constructed."""
    current_session = self._session
    return current_session
16
@property
def database(self):
    """Database name this client was constructed with."""
    database_name = self._database
    return database_name
20
@staticmethod
def _init_driver(endpoint, database, timeout=30):
    """Create a YDB driver and wait until it is ready.

    Args:
        endpoint: YDB endpoint passed to ``DriverConfig``.
        database: database name passed to ``DriverConfig``.
        timeout: seconds to wait for the driver to become ready
            (defaults to 30, the previously hard-coded value).

    Returns:
        The ready ``ydb.Driver``.

    Raises:
        Whatever ``Driver.wait`` raises when the driver does not become ready.
        The driver is stopped before the exception propagates, so a failed
        connection attempt does not leak it.
    """
    # An empty auth token is passed; presumably the target instance needs no
    # authentication — TODO confirm.
    config = ydb_native_client.DriverConfig(
        endpoint=endpoint, database=database, auth_token='',
    )
    driver = ydb_native_client.Driver(config)
    try:
        driver.wait(timeout=timeout)
    except BaseException:
        # Release the half-initialized driver before re-raising.
        driver.stop()
        raise
    return driver
29
30
def _prepare_column(column, version=None):
    """Build a ``ydb.Column`` from a column spec.

    Args:
        column: mapping with ``'name'`` and ``'type'`` keys. ``'type'`` is the
            name of a ``ydb.PrimitiveType`` attribute, optionally suffixed
            with ``'?'``.
        version: schema syntax version.
            - ``None`` or ``1``: every column is wrapped in ``OptionalType``,
              and the type name is used as-is (no ``'?'`` handling).
            - Any other value: only ``'?'``-suffixed types are wrapped in
              ``OptionalType``; other types are used directly.

    Returns:
        ``ydb.Column`` named ``column['name']`` with the resolved type.

    Raises:
        KeyError: if ``column`` lacks ``'type'`` or ``'name'``.
        AttributeError: if the type name is not a ``PrimitiveType``
            attribute. This includes an empty type string, which previously
            raised ``IndexError``.
    """
    type_name = column['type']
    primitive_types = ydb_native_client.PrimitiveType

    if version in (None, 1):
        column_type = ydb_native_client.OptionalType(
            getattr(primitive_types, type_name),
        )
    elif type_name.endswith('?'):
        column_type = ydb_native_client.OptionalType(
            getattr(primitive_types, type_name[:-1]),
        )
    else:
        column_type = getattr(primitive_types, type_name)

    return ydb_native_client.Column(column['name'], column_type)
45
46
def _prepare_index(index):
    """Build a ``ydb.TableIndex`` from an index spec.

    Args:
        index: mapping with ``'name'`` (the index name) and
            ``'index_columns'`` (iterable of column names to index).

    Returns:
        The result of ``TableIndex.with_index_columns``.
    """
    table_index = ydb_native_client.TableIndex(index['name'])
    indexed_columns = tuple(index['index_columns'])
    return table_index.with_index_columns(*indexed_columns)
51
52
def create_table(client, schema):
    """Create the table described by *schema* in the client's database.

    Args:
        client: object exposing ``session`` (whose ``create_table`` is
            called) and ``database`` (used to build the table path).
        schema: mapping with keys
            - ``'path'``: table path inside the database;
            - ``'primary_key'``: iterable of primary-key column names;
            - ``'schema'``: iterable of column specs for ``_prepare_column``;
            - ``'indexes'`` (optional): index specs for ``_prepare_index``;
            - ``'syntax_version'`` (optional): forwarded to
              ``_prepare_column``.
    """
    syntax_version = schema.get('syntax_version')
    create = client.session.create_table
    full_path = '/{}/{}'.format(client.database, schema['path'])

    description = ydb_native_client.TableDescription()
    description = description.with_primary_keys(*schema['primary_key'])
    description = description.with_columns(
        *(_prepare_column(spec, syntax_version) for spec in schema['schema'])
    )
    description = description.with_indexes(
        *(_prepare_index(spec) for spec in schema.get('indexes', []))
    )

    create(full_path, description)
66
67
def drop_table(client, table):
    """Drop *table* from the client's database, ignoring ordinary failures.

    Best-effort: any ``Exception`` raised while building the path or dropping
    the table (for example, presumably, when the table does not exist) is
    logged at DEBUG level and suppressed. ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed, unlike the former bare ``except``.

    Args:
        client: object exposing ``session`` (whose ``drop_table`` is called)
            and ``database`` (used to build the table path).
        table: table path inside the database.
    """
    import logging  # local import: only needed on the failure path

    try:
        client.session.drop_table('/{}/{}'.format(client.database, table))
    except Exception:  # noqa pylint: disable=broad-except
        logging.getLogger(__name__).debug(
            'Ignoring failure to drop YDB table %r', table, exc_info=True,
        )