userver: /data/code/userver/testsuite/pytest_plugins/pytest_userver/plugins/ydb/client.py Source File
Loading...
Searching...
No Matches
client.py
1import ydb as ydb_native_client
2
3
5 def __init__(self, endpoint, database):
6 self._driver = self._init_driver(endpoint, database)
7 self._database = database
8 self._session = self._driver.table_client.session().create()
9
10 def execute(self, query):
11 return self._session.transaction().execute(query, commit_tx=True)
12
13 @property
14 def session(self):
15 return self._session
16
17 @property
18 def database(self):
19 return self._database
20
21 @staticmethod
22 def _init_driver(endpoint, database):
23 config = ydb_native_client.DriverConfig(
24 endpoint=endpoint,
25 database=database,
26 auth_token='',
27 )
28 driver = ydb_native_client.Driver(config)
29 driver.wait(timeout=30)
30 return driver
31
32
def _prepare_column(column, version=None):
    """Build a ``ydb.Column`` from a column description mapping.

    Args:
        column: mapping with ``'name'`` and ``'type'`` keys. ``'type'`` names
            a ``ydb.PrimitiveType`` attribute, optionally followed by ``'?'``
            when ``version`` is neither ``None`` nor ``1``.
        version: schema syntax version. With ``None`` or ``1`` every column is
            wrapped in ``OptionalType``. With any other value a column is
            optional only when its type name ends with ``'?'``.

    Returns:
        ``ydb.Column`` named ``column['name']`` with the resolved type.

    Raises:
        KeyError: if ``'name'`` or ``'type'`` is missing.
        AttributeError: if the type name is not a ``PrimitiveType`` member.
    """
    type_name = column['type']
    if version is None or version == 1:
        # Legacy syntax: all columns are nullable.
        optional = True
    else:
        # Newer syntax: a trailing '?' marks a nullable column.
        # endswith() (unlike indexing [-1]) is safe on an empty string.
        optional = type_name.endswith('?')
        if optional:
            type_name = type_name[:-1]

    column_type = getattr(ydb_native_client.PrimitiveType, type_name)
    if optional:
        column_type = ydb_native_client.OptionalType(column_type)
    return ydb_native_client.Column(column['name'], column_type)
47
48
def _prepare_index(index):
    """Build a ``ydb.TableIndex`` from an index description mapping.

    Args:
        index: mapping with ``'name'`` and ``'index_columns'`` (an iterable
            of column names).

    Returns:
        The ``TableIndex`` returned by ``with_index_columns``.
    """
    table_index = ydb_native_client.TableIndex(index['name'])
    columns = tuple(index['index_columns'])
    return table_index.with_index_columns(*columns)
53
54
def create_table(client, schema):
    """Create the table described by ``schema`` in ``client``'s database.

    Args:
        client: object exposing ``session`` and ``database`` (such as the
            client class defined in this module).
        schema: mapping with ``'path'`` (table path inside the database),
            ``'primary_key'`` (column names) and ``'schema'`` (column
            mappings for ``_prepare_column``). It may also contain ``'indexes'``
            (index mappings for ``_prepare_index``) and ``'syntax_version'``
            (forwarded to ``_prepare_column``).
    """
    version = schema.get('syntax_version')
    session = client.session
    table_path = '/{}/{}'.format(client.database, schema['path'])

    # Built step by step, in the same order as a single fluent chain.
    description = ydb_native_client.TableDescription()
    description = description.with_primary_keys(*schema['primary_key'])
    columns = [_prepare_column(column, version) for column in schema['schema']]
    description = description.with_columns(*columns)
    indexes = [_prepare_index(index) for index in schema.get('indexes', [])]
    description = description.with_indexes(*indexes)

    session.create_table(table_path, description)
64
65
def drop_table(client, table):
    """Drop ``table`` from ``client``'s database on a best-effort basis.

    Any ``Exception`` raised by the session's ``drop_table`` is ignored
    (presumably so that dropping a table that does not exist is harmless).
    ``BaseException`` subclasses such as ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed.

    Args:
        client: object exposing ``session`` and ``database``.
        table: table path inside the database.
    """
    path = '/{}/{}'.format(client.database, table)
    try:
        client.session.drop_table(path)
    except Exception:  # noqa pylint: disable=broad-except
        pass