5 from autotest_lib
.client
.common_lib
import global_config
6 from autotest_lib
.client
.common_lib
.test_utils
import mock
7 from autotest_lib
.database
import database_connection
9 _CONFIG_SECTION
= 'TKO'
16 _CONNECT_KWARGS
= dict(host
=_HOST
, username
=_USER
, password
=_PASS
,
20 class FakeDatabaseError(Exception):
24 class DatabaseConnectionTest(unittest
.TestCase
):
26 self
.god
= mock
.mock_god()
27 self
.god
.stub_function(time
, 'sleep')
31 global_config
.global_config
.reset_config_values()
35 def _get_database_connection(self
, config_section
=_CONFIG_SECTION
):
36 if config_section
== _CONFIG_SECTION
:
37 self
._override
_config
()
38 db
= database_connection
.DatabaseConnection(config_section
)
40 self
._fake
_backend
= self
.god
.create_mock_class(
41 database_connection
._GenericBackend
, 'fake_backend')
42 for exception
in database_connection
._DB
_EXCEPTIONS
:
43 setattr(self
._fake
_backend
, exception
, FakeDatabaseError
)
44 self
._fake
_backend
.rowcount
= 0
46 def get_fake_backend(db_type
):
47 self
._db
_type
= db_type
48 return self
._fake
_backend
49 self
.god
.stub_with(db
, '_get_backend', get_fake_backend
)
51 db
.reconnect_delay_sec
= _RECONNECT_DELAY
55 def _override_config(self
):
56 c
= global_config
.global_config
57 c
.override_config_value(_CONFIG_SECTION
, 'host', _HOST
)
58 c
.override_config_value(_CONFIG_SECTION
, 'user', _USER
)
59 c
.override_config_value(_CONFIG_SECTION
, 'password', _PASS
)
60 c
.override_config_value(_CONFIG_SECTION
, 'database', _DB_NAME
)
61 c
.override_config_value(_CONFIG_SECTION
, 'db_type', _DB_TYPE
)
64 def test_connect(self
):
65 db
= self
._get
_database
_connection
(config_section
=None)
66 self
._fake
_backend
.connect
.expect_call(**_CONNECT_KWARGS
)
68 db
.connect(db_type
=_DB_TYPE
, host
=_HOST
, username
=_USER
,
69 password
=_PASS
, db_name
=_DB_NAME
)
71 self
.assertEquals(self
._db
_type
, _DB_TYPE
)
72 self
.god
.check_playback()
75 def test_global_config(self
):
76 db
= self
._get
_database
_connection
()
77 self
._fake
_backend
.connect
.expect_call(**_CONNECT_KWARGS
)
81 self
.assertEquals(self
._db
_type
, _DB_TYPE
)
82 self
.god
.check_playback()
85 def _expect_reconnect(self
, fail
=False):
86 self
._fake
_backend
.disconnect
.expect_call()
87 call
= self
._fake
_backend
.connect
.expect_call(**_CONNECT_KWARGS
)
89 call
.and_raises(FakeDatabaseError())
92 def _expect_fail_and_reconnect(self
, num_reconnects
, fail_last
=False):
93 self
._fake
_backend
.connect
.expect_call(**_CONNECT_KWARGS
).and_raises(
95 for i
in xrange(num_reconnects
):
96 time
.sleep
.expect_call(_RECONNECT_DELAY
)
97 if i
< num_reconnects
- 1:
98 self
._expect
_reconnect
(fail
=True)
100 self
._expect
_reconnect
(fail
=fail_last
)
103 def test_connect_retry(self
):
104 db
= self
._get
_database
_connection
()
105 self
._expect
_fail
_and
_reconnect
(1)
108 self
.god
.check_playback()
110 self
._fake
_backend
.disconnect
.expect_call()
111 self
._expect
_fail
_and
_reconnect
(0)
112 self
.assertRaises(FakeDatabaseError
, db
.connect
,
113 try_reconnecting
=False)
114 self
.god
.check_playback()
116 db
.reconnect_enabled
= False
117 self
._fake
_backend
.disconnect
.expect_call()
118 self
._expect
_fail
_and
_reconnect
(0)
119 self
.assertRaises(FakeDatabaseError
, db
.connect
)
120 self
.god
.check_playback()
123 def test_max_reconnect(self
):
124 db
= self
._get
_database
_connection
()
125 db
.max_reconnect_attempts
= 5
126 self
._expect
_fail
_and
_reconnect
(5, fail_last
=True)
128 self
.assertRaises(FakeDatabaseError
, db
.connect
)
129 self
.god
.check_playback()
132 def test_reconnect_forever(self
):
133 db
= self
._get
_database
_connection
()
134 db
.max_reconnect_attempts
= database_connection
.RECONNECT_FOREVER
135 self
._expect
_fail
_and
_reconnect
(30)
138 self
.god
.check_playback()
141 def _simple_connect(self
, db
):
142 self
._fake
_backend
.connect
.expect_call(**_CONNECT_KWARGS
)
144 self
.god
.check_playback()
147 def test_disconnect(self
):
148 db
= self
._get
_database
_connection
()
149 self
._simple
_connect
(db
)
150 self
._fake
_backend
.disconnect
.expect_call()
153 self
.god
.check_playback()
156 def test_execute(self
):
157 db
= self
._get
_database
_connection
()
158 self
._simple
_connect
(db
)
160 self
._fake
_backend
.execute
.expect_call('query', params
)
162 db
.execute('query', params
)
163 self
.god
.check_playback()
166 def test_execute_retry(self
):
167 db
= self
._get
_database
_connection
()
168 self
._simple
_connect
(db
)
169 self
._fake
_backend
.execute
.expect_call('query', None).and_raises(
171 self
._expect
_reconnect
()
172 self
._fake
_backend
.execute
.expect_call('query', None)
175 self
.god
.check_playback()
177 self
._fake
_backend
.execute
.expect_call('query', None).and_raises(
179 self
.assertRaises(FakeDatabaseError
, db
.execute
, 'query',
180 try_reconnecting
=False)
183 if __name__
== '__main__':