Update wiki links to the new short URL
[aur.git] / aurweb / db.py
blob 02aeba3875ef766e52482516677b203fbf02ece9
1 try:
2 import mysql.connector
3 except ImportError:
4 pass
6 try:
7 import sqlite3
8 except ImportError:
9 pass
11 import aurweb.config
# Module-level SQLAlchemy engine cache. It stays None until get_engine()
# creates the engine on its first call and stores it here for reuse.
engine = None  # See get_engine
def get_sqlalchemy_url():
    """
    Build an SQLAlchemy URL for use with create_engine based on the aurweb
    configuration.

    The `backend` option of the `database` section selects the dialect; the
    other connection parameters are read from that same section.

    :return: a sqlalchemy.engine.url.URL
    :raises ValueError: if the configured backend is neither 'mysql' nor
        'sqlite'
    """
    import sqlalchemy

    url_class = sqlalchemy.engine.url.URL
    # SQLAlchemy 1.4 deprecated calling the URL constructor directly in favour
    # of URL.create() (the only supported way in 2.0). Releases older than
    # 1.4 have no create(), so fall back to the plain constructor there.
    make_url = getattr(url_class, 'create', url_class)

    aur_db_backend = aurweb.config.get('database', 'backend')
    if aur_db_backend == 'mysql':
        return make_url(
            'mysql+mysqlconnector',
            username=aurweb.config.get('database', 'user'),
            password=aurweb.config.get('database', 'password'),
            host=aurweb.config.get('database', 'host'),
            database=aurweb.config.get('database', 'name'),
            query={
                'unix_socket': aurweb.config.get('database', 'socket'),
                # URL.create() only accepts string query values; the
                # mysqlconnector dialect coerces 'buffered' back to a bool.
                'buffered': 'true',
            },
        )
    elif aur_db_backend == 'sqlite':
        return make_url(
            'sqlite',
            database=aurweb.config.get('database', 'name'),
        )
    else:
        raise ValueError('unsupported database backend')
def get_engine():
    """
    Return the global SQLAlchemy engine.

    The engine is created on the first call to get_engine and then stored in
    the `engine` global variable for the next calls.

    :return: the shared SQLAlchemy engine built from get_sqlalchemy_url()
    """
    from sqlalchemy import create_engine
    global engine
    if engine is None:
        connect_args = {}
        if aurweb.config.get('database', 'backend') == 'sqlite':
            # check_same_thread is for a SQLite technicality
            # https://fastapi.tiangolo.com/tutorial/sql-databases/#note
            # It is a sqlite3-only connect() argument: other drivers such as
            # mysql.connector reject unknown keywords, so only pass it here.
            connect_args['check_same_thread'] = False
        engine = create_engine(get_sqlalchemy_url(),
                               connect_args=connect_args)
    return engine
def connect():
    """
    Open and return a connection from the shared SQLAlchemy engine.

    Connections handed out by the engine are usually pooled; see
    <https://docs.sqlalchemy.org/en/13/core/connections.html>.

    SQLAlchemy connections are also context managers, so use the returned
    object in a Python `with` statement or through FastAPI’s dependency
    injection.
    """
    shared_engine = get_engine()
    return shared_engine.connect()
class Connection:
    """
    Thin DB-API wrapper that opens a connection for the configured backend.

    Queries passed to execute() use '?' placeholders (qmark style). For
    drivers whose paramstyle is 'format' or 'pyformat' they are rewritten
    before execution.

    Instances can be used as context managers: leaving the `with` block
    closes the underlying connection. Nothing is committed automatically, so
    call commit() inside the block to keep changes.
    """

    # Underlying DB-API connection object; set by __init__.
    _conn = None
    # DB-API paramstyle reported by the backend driver module.
    _paramstyle = None

    def __init__(self):
        """
        Connect using the `database` section of the aurweb configuration.

        :raises ValueError: if the configured backend is neither 'mysql' nor
            'sqlite'
        """
        aur_db_backend = aurweb.config.get('database', 'backend')

        if aur_db_backend == 'mysql':
            aur_db_host = aurweb.config.get('database', 'host')
            aur_db_name = aurweb.config.get('database', 'name')
            aur_db_user = aurweb.config.get('database', 'user')
            aur_db_pass = aurweb.config.get('database', 'password')
            aur_db_socket = aurweb.config.get('database', 'socket')
            self._conn = mysql.connector.connect(host=aur_db_host,
                                                 user=aur_db_user,
                                                 passwd=aur_db_pass,
                                                 db=aur_db_name,
                                                 unix_socket=aur_db_socket,
                                                 buffered=True)
            self._paramstyle = mysql.connector.paramstyle
        elif aur_db_backend == 'sqlite':
            aur_db_name = aurweb.config.get('database', 'name')
            self._conn = sqlite3.connect(aur_db_name)
            self._paramstyle = sqlite3.paramstyle
        else:
            raise ValueError('unsupported database backend')

    def __enter__(self):
        """Return this wrapper so it can be used in a `with` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions from the block propagate."""
        self.close()
        return False

    def execute(self, query, params=()):
        """
        Execute `query` with `params` on a new cursor and return that cursor.

        :param query: SQL text using '?' placeholders
        :param params: sequence of values bound to the placeholders
        :return: the DB-API cursor the query was executed on
        :raises ValueError: if the driver's paramstyle is not supported
        """
        if self._paramstyle in ('format', 'pyformat'):
            # Escape literal '%' so the driver does not read it as a format
            # marker, then turn every '?' placeholder into '%s'.
            # NOTE(review): this also rewrites '?' inside SQL string literals.
            query = query.replace('%', '%%').replace('?', '%s')
        elif self._paramstyle == 'qmark':
            pass
        else:
            raise ValueError('unsupported paramstyle')

        cur = self._conn.cursor()
        cur.execute(query, params)

        return cur

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()