3 import unittest
, tempfile
, os
6 from autotest_lib
.client
.common_lib
import global_config
7 from autotest_lib
.database
import database_connection
, migrate
9 # Which section of the global config to pull info from. We won't actually use
10 # that DB, we'll use the corresponding test DB (test_<db name>).
11 CONFIG_DB
= 'AUTOTEST_WEB'
15 class DummyMigration(object):
17 Dummy migration class that records all migrations done in a class
23 def __init__(self
, version
):
24 self
.version
= version
25 self
.name
= '%03d_test' % version
29 def get_migrations_done(cls
):
30 return cls
.migrations_done
34 def clear_migrations_done(cls
):
35 cls
.migrations_done
= []
39 def do_migration(cls
, version
, direction
):
40 cls
.migrations_done
.append((version
, direction
))
43 def migrate_up(self
, manager
):
44 self
.do_migration(self
.version
, 'up')
46 manager
.create_migrate_table()
49 def migrate_down(self
, manager
):
50 self
.do_migration(self
.version
, 'down')
53 MIGRATIONS
= [DummyMigration(n
) for n
in xrange(1, NUM_MIGRATIONS
+ 1)]
56 class TestableMigrationManager(migrate
.MigrationManager
):
57 def _set_migrations_dir(self
, migrations_dir
=None):
61 def get_migrations(self
, minimum_version
=None, maximum_version
=None):
62 minimum_version
= minimum_version
or 1
63 maximum_version
= maximum_version
or len(MIGRATIONS
)
64 return MIGRATIONS
[minimum_version
-1:maximum_version
]
67 class MigrateManagerTest(unittest
.TestCase
):
70 database_connection
.DatabaseConnection
.get_test_database())
71 self
._database
.connect()
72 self
.manager
= TestableMigrationManager(self
._database
)
73 DummyMigration
.clear_migrations_done()
77 self
._database
.disconnect()
81 self
.manager
.do_sync_db()
82 self
.assertEquals(self
.manager
.get_db_version(), NUM_MIGRATIONS
)
83 self
.assertEquals(DummyMigration
.get_migrations_done(),
84 [(1, 'up'), (2, 'up'), (3, 'up')])
86 DummyMigration
.clear_migrations_done()
87 self
.manager
.do_sync_db(0)
88 self
.assertEquals(self
.manager
.get_db_version(), 0)
89 self
.assertEquals(DummyMigration
.get_migrations_done(),
90 [(3, 'down'), (2, 'down'), (1, 'down')])
93 def test_sync_one_by_one(self
):
94 for version
in xrange(1, NUM_MIGRATIONS
+ 1):
95 self
.manager
.do_sync_db(version
)
96 self
.assertEquals(self
.manager
.get_db_version(),
99 DummyMigration
.get_migrations_done()[-1],
102 for version
in xrange(NUM_MIGRATIONS
- 1, -1, -1):
103 self
.manager
.do_sync_db(version
)
104 self
.assertEquals(self
.manager
.get_db_version(),
107 DummyMigration
.get_migrations_done()[-1],
108 (version
+ 1, 'down'))
111 def test_null_sync(self
):
112 self
.manager
.do_sync_db()
113 DummyMigration
.clear_migrations_done()
114 self
.manager
.do_sync_db()
115 self
.assertEquals(DummyMigration
.get_migrations_done(), [])
118 class DummyMigrationManager(object):
123 def execute_script(self
, script
):
124 self
.calls
.append(script
)
127 class MigrationTest(unittest
.TestCase
):
129 self
.manager
= DummyMigrationManager()
132 def _do_migration(self
, migration_module
):
133 migration
= migrate
.Migration('name', 1, migration_module
)
134 migration
.migrate_up(self
.manager
)
135 migration
.migrate_down(self
.manager
)
137 self
.assertEquals(self
.manager
.calls
, ['foo', 'bar'])
140 def test_migration_with_methods(self
):
141 class DummyMigration(object):
143 def migrate_up(manager
):
144 manager
.execute_script('foo')
148 def migrate_down(manager
):
149 manager
.execute_script('bar')
151 self
._do
_migration
(DummyMigration
)
154 def test_migration_with_strings(self
):
155 class DummyMigration(object):
159 self
._do
_migration
(DummyMigration
)
162 if __name__
== '__main__':