3 import os
, sys
, re
, subprocess
, tempfile
4 from optparse
import OptionParser
6 import MySQLdb
, MySQLdb
.constants
.ER
7 from autotest_lib
.client
.common_lib
import global_config
, utils
8 from autotest_lib
.database
import database_connection
# Name of the table that stores the schema version of a managed database.
MIGRATE_TABLE = 'migrate_info'

# Parent of the directory that contains this file.
_AUTODIR = os.path.join(os.path.dirname(__file__), '..')

# Default migrations directory for each known config section.
_MIGRATIONS_DIRS = {
    'AUTOTEST_WEB': os.path.join(_AUTODIR, 'frontend', 'migrations'),
    'TKO': os.path.join(_AUTODIR, 'tko', 'migrations'),
}

# Fallback for config sections not listed in _MIGRATIONS_DIRS.  The path is
# relative, so it is resolved against the current working directory.
_DEFAULT_MIGRATIONS_DIR = 'migrations'  # use CWD
class Migration(object):
    """A single schema migration backed by an imported Python module.

    For each direction the module must provide at least one of:

    * a callable (``migrate_up`` / ``migrate_down``), invoked with the
      manager as its only argument, or
    * a SQL script string (``UP_SQL`` / ``DOWN_SQL``), passed to
      ``manager.execute_script``.

    If both are present for a direction, the callable is used.
    """

    # (callable attribute name, SQL script attribute name) per direction.
    _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
    _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')

    def __init__(self, name, version, module):
        """Store the migration and verify that both directions are defined.

        @param name: migration name (the file name without its extension).
        @param version: integer version this migration upgrades to.
        @param module: imported module object holding the migration.
        @raises AssertionError: if the module lacks an up or a down
                definition.
        """
        self.name = name
        self.version = version
        self.module = module
        self._check_attributes(self._UP_ATTRIBUTES)
        self._check_attributes(self._DOWN_ATTRIBUTES)

    @classmethod
    def from_file(cls, filename):
        """Build a Migration by importing the module named after a file.

        @param filename: bare file name such as ``'001_initial.py'``.  The
                first three characters are parsed as the version; the module
                must be importable through ``sys.path``.
        @returns a new instance of cls.
        """
        version = int(filename[:3])
        name = os.path.splitext(filename)[0]
        module = __import__(name, globals(), locals(), [])
        return cls(name, version, module)

    def _check_attributes(self, attributes):
        """Assert that the module defines the callable or the SQL script."""
        method_name, sql_name = attributes
        assert (hasattr(self.module, method_name) or
                hasattr(self.module, sql_name)), (
            'migration %s defines neither %s nor %s'
            % (self.name, method_name, sql_name))

    def _execute_migration(self, attributes, manager):
        """Run one direction of the migration against manager.

        @param attributes: one of _UP_ATTRIBUTES / _DOWN_ATTRIBUTES.
        @param manager: object passed to the callable, or whose
                execute_script() receives the SQL script.
        """
        method_name, sql_name = attributes
        method = getattr(self.module, method_name, None)
        if method is not None:
            assert callable(method)
            method(manager)
        else:
            sql = getattr(self.module, sql_name)
            assert isinstance(sql, basestring)
            manager.execute_script(sql)

    def migrate_up(self, manager):
        """Apply this migration."""
        self._execute_migration(self._UP_ATTRIBUTES, manager)

    def migrate_down(self, manager):
        """Revert this migration."""
        self._execute_migration(self._DOWN_ATTRIBUTES, manager)
class MigrationManager(object):
    """Applies versioned schema migrations through a database connection.

    The current schema version is kept in a one-row table named by
    MIGRATE_TABLE.  Migration modules are discovered in ``migrations_dir``
    as files named ``NNN_<anything>.py``.

    Progress messages are written with print() and a single pre-formatted
    string, which produces the same output under the Python 2 print
    statement and the Python 3 print function.
    """

    def __init__(self, database_connection, migrations_dir=None, force=False):
        """
        @param database_connection: connection object used for all queries.
        @param migrations_dir: directory holding the migration files; when
                None it is derived from the connection's config section.
        @param force: when true, confirm_initialization() does not prompt.
        """
        self._database = database_connection
        self.force = force
        # A boolean, this will only be set to True if this migration should be
        # simulated rather than actually taken.  For use with migrations that
        # may make destructive queries
        self.simulate = False
        self._set_migrations_dir(migrations_dir)

    def _set_migrations_dir(self, migrations_dir=None):
        """Resolve the migrations directory and make it importable."""
        config_section = self._config_section()
        if migrations_dir is None:
            migrations_dir = os.path.abspath(
                _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
        self.migrations_dir = migrations_dir
        # Migration.from_file() imports modules by bare name, so the
        # directory has to be on sys.path.
        sys.path.append(migrations_dir)
        assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"

    def _config_section(self):
        """Return the global config section of the underlying connection."""
        return self._database.global_config_section

    def get_db_name(self):
        """Return the name of the database currently being acted on."""
        return self._database.get_database_info()['db_name']

    def execute(self, query, *parameters):
        """Run one query, passing parameters as a tuple to the connection."""
        return self._database.execute(query, parameters)

    def execute_script(self, script):
        """Run every non-empty ';'-separated statement of script in order.

        The split is purely textual, so a ';' inside a string literal or a
        comment also ends a statement.
        """
        for chunk in script.split(';'):
            statement = chunk.strip()
            if statement:
                self.execute(statement)

    def check_migrate_table_exists(self):
        """Return True if the version table can be queried, else False."""
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        except self._database.DatabaseError:
            # we can't check for more specifics due to differences between DB
            # backends (we can't even check for a subclass of DatabaseError)
            return False
        return True

    def create_migrate_table(self):
        """Create (or empty) the version table and set the version to 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self._database.rowcount == 1

    def set_db_version(self, version):
        """Record version as the current schema version."""
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self._database.rowcount == 1

    def get_db_version(self):
        """Return the recorded schema version.

        Returns 0 when the version table is missing or holds no row.
        """
        if not self.check_migrate_table_exists():
            return 0
        rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]

    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Load the migrations found in migrations_dir, ordered by file name.

        @param minimum_version: if not None, lowest version to include.
        @param maximum_version: if not None, highest version to include.
        @returns a list of Migration objects.
        """
        # os.listdir() returns names in arbitrary order; sort so that the
        # zero-padded numeric prefix determines the order.
        migrate_files = sorted(
            filename for filename in os.listdir(self.migrations_dir)
            if re.match(r'^\d\d\d_.*\.py$', filename))
        migrations = [Migration.from_file(filename)
                      for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations

    def do_migration(self, migration, migrate_up=True):
        """Apply or revert one migration and record the resulting version.

        @raises AssertionError: if the database is not at the version this
                step expects to start from.
        """
        direction = 'up' if migrate_up else 'down'
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)

    def migrate_to_version(self, version):
        """Migrate the database up or down until it is at version."""
        current_version = self.get_db_version()
        if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
            self._migrate_from_base()
            current_version = self.get_db_version()

        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            # do_migration() requires each downgrade to start at that
            # migration's own version, so go from highest to lowest.
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % version)

    def _migrate_from_base(self):
        """Load the bundled version-51 schema into an uninitialized DB."""
        self.confirm_initialization()

        migration_script = utils.read_file(
            os.path.join(os.path.dirname(__file__), 'schema_051.sql'))
        migration_script = migration_script % (
            dict(username=self._database.get_database_info()['username']))
        self.execute_script(migration_script)

        self.create_migrate_table()
        self.set_db_version(51)

    def confirm_initialization(self):
        """Ask before recreating the database, unless force is set.

        @raises Exception: if the user answers anything other than 'yes'.
        """
        if self.force:
            return
        response = raw_input(
            'Your %s database does not appear to be initialized.  Do you '
            'want to recreate it (this will result in loss of any existing '
            'data) (yes/No)? ' % self.get_db_name())
        if response != 'yes':
            raise Exception('User has chosen to abort migration')

    def get_latest_version(self):
        """Return the version of the last migration in migrations_dir.

        @raises IndexError: if the directory contains no migration files.
        """
        migrations = self.get_migrations()
        return migrations[-1].version

    def migrate_to_latest(self):
        """Migrate to the highest version available in migrations_dir."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)

    def initialize_test_db(self):
        """Create database 'test_<db name>' and connect to it."""
        db_name = self.get_db_name()
        test_db_name = 'test_' + db_name
        # first, connect to no DB so we can create a test DB
        self._database.connect(db_name='')
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self._database.disconnect()
        # now connect to the test DB
        self._database.connect(db_name=test_db_name)

    def remove_test_db(self):
        """Drop the currently connected database and reconnect."""
        print('Removing test DB')
        self.execute('DROP DATABASE ' + self.get_db_name())
        # reset connection back to real DB
        self._database.disconnect()
        self._database.connect()

    def get_mysql_args(self):
        """Return mysql/mysqldump command-line arguments for the current DB.

        The values are interpolated without shell quoting and include the
        password in clear text.
        """
        return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
                self._database.get_database_info())

    def migrate_to_version_or_latest(self, version):
        """Migrate to version, or to the latest version if version is None."""
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)

    def do_sync_db(self, version=None):
        """Migrate the real database to version (default: latest)."""
        print('Migration starting for database %s' % self.get_db_name())
        self.migrate_to_version_or_latest(version)
        print('Migration complete')

    def test_sync_db(self, version=None):
        """
        Create a fresh DB and run all migrations on it.
        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            # always drop the test DB, even if a migration failed
            self.remove_test_db()
        print('Test finished successfully')

    def simulate_sync_db(self, version=None):
        """
        Create a fresh DB, copy the existing DB to it, and then
        try to synchronize it.
        """
        db_version = self.get_db_version()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        self.initialize_and_fill_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
        finally:
            # always drop the test DB, even if a migration failed
            self.remove_test_db()
        print('Test finished successfully')

    def initialize_and_fill_test_db(self):
        """Dump the current DB and load the dump into a new test DB."""
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        try:
            os.system('mysqldump %s >%s' %
                      (self.get_mysql_args(), dump_file))
            # fill in test DB
            self.initialize_test_db()
            print('Filling in test DB')
            os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
        finally:
            # mkstemp() leaves both the descriptor and the file to the
            # caller; the dump holds a full copy of the data, so never
            # leave it behind.
            os.close(dump_fd)
            os.remove(dump_file)
# Usage text printed when the command is missing or not recognized.
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
    -d --database   Which database to act on
    -a --action     Which action to perform""" % sys.argv[0]


def main():
    """Parse the command line and run the requested migration command.

    Positional arguments are ``<command> [version]``, where command is one
    of sync, test, simulate or safesync.  Without a version the database is
    migrated to the latest available one.  A missing or unknown command
    prints the usage text.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database",
                      default="AUTOTEST_WEB")
    parser.add_option("-a", "--action", help="what action to perform",
                      dest="action")
    parser.add_option("-f", "--force", help="don't ask for confirmation",
                      action="store_true")
    parser.add_option('--debug', help='print all DB queries',
                      action='store_true')
    (options, args) = parser.parse_args()
    manager = get_migration_manager(db_name=options.database,
                                    debug=options.debug, force=options.force)

    if not args:
        print(USAGE)
        return

    command = args[0]
    # None means "migrate to the latest available version".
    version = int(args[1]) if len(args) > 1 else None

    if command == 'sync':
        manager.do_sync_db(version)
    elif command == 'test':
        manager.simulate = True
        manager.test_sync_db(version)
    elif command == 'simulate':
        manager.simulate = True
        manager.simulate_sync_db(version)
    elif command == 'safesync':
        # Rehearse on a copy of the data first; a failure there raises
        # before the real database is touched.
        print('Simulating migration')
        manager.simulate = True
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.simulate = False
        manager.do_sync_db(version)
    else:
        print(USAGE)
def get_migration_manager(db_name, debug, force):
    """Create a MigrationManager with an open connection for db_name.

    @param db_name: passed to database_connection.DatabaseConnection.
    @param debug: stored on the connection's ``debug`` attribute.
    @param force: forwarded to MigrationManager.
    @returns the new MigrationManager.
    """
    database = database_connection.DatabaseConnection(db_name)
    database.debug = debug
    database.reconnect_enabled = False
    # The manager issues queries right away and never opens the initial
    # connection itself, so open it here.
    database.connect()
    return MigrationManager(database, force=force)
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()