76 lines
3 KiB
Python
76 lines
3 KiB
Python
import logging
|
|
from pathlib import Path
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
from sqlalchemy import text
|
|
|
|
|
|
class MigrationRunner:
    """Apply plain-SQL migration files to a database, once each, in filename order.

    Every ``*.sql`` file in ``migrations_dir`` is one migration whose version is
    the file stem (``001_init.sql`` -> ``"001_init"``). Applied versions are
    recorded in the ``MIGRATIONS_SCHEMA_NAME`` table so later runs skip them.
    """

    # Name of the bookkeeping *table* (not a database schema) that stores the
    # versions of already-applied migrations.
    MIGRATIONS_SCHEMA_NAME: str = "schema_migrations"

    def __init__(
        self, database_url: str, migrations_dir: "str | Path" = "migrations/sql"
    ) -> None:
        """Create the async engine and remember where migration files live.

        Args:
            database_url: SQLAlchemy async database URL passed to
                ``create_async_engine``.
            migrations_dir: Directory containing the ``*.sql`` migration files.
                Defaults to ``migrations/sql``, resolved relative to the current
                working directory.
        """
        self.engine = create_async_engine(database_url)
        self.migrations_dir = Path(migrations_dir)
        logging.info(
            "Migrator initializer: engine = %s, migrations_dir = %s",
            self.engine,
            self.migrations_dir,
        )

    async def get_applied_migrations(self) -> set:
        """Return the set of migration versions already recorded as applied.

        The bookkeeping table is created first if it does not exist yet, so this
        is safe to call against a fresh database.
        """
        async with self.engine.begin() as conn:
            await conn.execute(
                text(f"""
                    CREATE TABLE IF NOT EXISTS {self.MIGRATIONS_SCHEMA_NAME} (
                        version VARCHAR(50) PRIMARY KEY,
                        applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                """)
            )
            result = await conn.execute(
                text(f"SELECT version FROM {self.MIGRATIONS_SCHEMA_NAME}")
            )
            applied = {row[0] for row in result.fetchall()}

        logging.info("Received migrator data: %s", applied)
        return applied

    async def run_migrations(self):
        """Apply, in filename order, every migration not yet recorded as applied."""
        applied = await self.get_applied_migrations()
        logging.info("Applied migrations: %s", applied)

        migration_files = self._migration_files()
        logging.info("Migration files: %s", migration_files)

        for migration_file in migration_files:
            # The version stored in the bookkeeping table is the file name
            # without its extension.
            migration_name = migration_file.stem
            if migration_name in applied:
                logging.info("Skipping migration: %s (already applied)", migration_name)
                continue
            logging.info("Applying migration: %s", migration_name)
            await self._apply_migration(migration_file, migration_name)

    def _migration_files(self) -> list:
        """Return the ``*.sql`` files in ``migrations_dir``, sorted by path.

        Sorting is lexicographic, so file names need a zero-padded prefix
        (e.g. ``001_``, ``002_``) to run in the intended order. A missing
        directory yields an empty list and a warning.
        """
        if not self.migrations_dir.is_dir():
            logging.warning("Migrations directory not found: %s", self.migrations_dir)
            return []
        # is_file() skips directories that happen to match the pattern.
        return sorted(
            path for path in self.migrations_dir.glob("*.sql") if path.is_file()
        )

    @staticmethod
    def _split_sql_statements(sql_content: str) -> list:
        """Split a migration script into individual statements on ``;``.

        This is a plain text split: a ``;`` inside a string literal, a comment,
        or a function body (e.g. PostgreSQL ``$$ ... $$``) also ends a
        statement, so migration files must avoid those. Whitespace around each
        statement is stripped and empty fragments are dropped.
        """
        return [part.strip() for part in sql_content.split(";") if part.strip()]

    async def _apply_migration(self, migration_file: Path, migration_name: str):
        """Execute one migration file and record its version, in one transaction.

        The statements and the bookkeeping insert share a single
        ``engine.begin()`` transaction. On error it is rolled back and the
        migration is not recorded. Databases that auto-commit DDL may still keep
        statements that ran before the failure.

        Args:
            migration_file: Path of the ``.sql`` file to execute.
            migration_name: Version to record. It must fit the ``VARCHAR(50)``
                ``version`` column.

        Raises:
            Exception: Whatever the database driver raised. It is logged, then
                re-raised.
        """
        sql_content = migration_file.read_text(encoding="utf-8")
        statements = self._split_sql_statements(sql_content)

        async with self.engine.begin() as conn:
            try:
                for statement in statements:
                    await conn.execute(text(statement))
                # Record the version in the same transaction, so a migration is
                # only marked applied if its statements succeeded.
                await conn.execute(
                    text(
                        f"INSERT INTO {self.MIGRATIONS_SCHEMA_NAME} (version) "
                        "VALUES (:version)"
                    ),
                    {"version": migration_name},
                )
            except Exception:
                logging.exception("Error applying migration %s", migration_name)
                raise

        # Logged only after the transaction has committed.
        logging.info("Migration %s applied successfully", migration_name)
|