"""SQLite storage for Busbar Designer.

One file (`data/busbar.db` by default; override with `BUSBAR_DB` env var).
Three tables:
  projects   — full editor state (cells + busbars + params) per project.
  presets    — named param sets the user can apply to any project.
  snapshots  — per-project history; auto-pruned to the last N (env
               SNAPSHOT_RETENTION, default 20).

Connection model: opens a fresh sqlite3 connection per call inside a context
manager. Cheap (< 1ms) and avoids worker-shared state issues with gunicorn.
"""
from __future__ import annotations

import json
import os
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterator

DB_PATH = Path(os.environ.get("BUSBAR_DB", "data/busbar.db"))
SNAPSHOT_RETENTION = int(os.environ.get("SNAPSHOT_RETENTION", "20"))


@contextmanager
def _conn() -> Iterator[sqlite3.Connection]:
    """Yield a fresh connection to `DB_PATH`; commit on success, always close.

    The parent directory of `DB_PATH` is created if missing. Rows come back as
    `sqlite3.Row` and foreign-key enforcement is switched on (needed for the
    `ON DELETE CASCADE` on snapshots). If the body raises, the commit is
    skipped and the connection is closed with the transaction uncommitted.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    c = sqlite3.connect(str(DB_PATH))
    try:
        # Setup lives inside the try so a failing PRAGMA cannot leak the handle.
        c.row_factory = sqlite3.Row
        c.execute("PRAGMA foreign_keys = ON")
        yield c
        c.commit()
    finally:
        c.close()


def init_db() -> None:
    """Create the projects, presets and snapshots tables if they are missing."""
    with _conn() as c:
        c.executescript("""
            CREATE TABLE IF NOT EXISTS projects (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                data TEXT NOT NULL,
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS presets (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE,
                params TEXT NOT NULL,
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                project_id INTEGER NOT NULL
                    REFERENCES projects(id) ON DELETE CASCADE,
                data TEXT NOT NULL,
                note TEXT,
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            );
            CREATE INDEX IF NOT EXISTS idx_snap_proj
                ON snapshots(project_id, created_at DESC);
        """)


# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------

def list_projects() -> list[dict]:
    """Return id, name and timestamps of every project, latest update first.

    The `data` column is not included in the result.
    """
    with _conn() as c:
        rows = c.execute(
            "SELECT id, name, created_at, updated_at FROM projects "
            "ORDER BY updated_at DESC"
        ).fetchall()
        return [dict(r) for r in rows]


def get_project(pid: int) -> dict | None:
    """Return project `pid` with its `data` JSON-decoded, or None if absent."""
    with _conn() as c:
        row = c.execute(
            "SELECT id, name, data, created_at, updated_at "
            "FROM projects WHERE id=?",
            (pid,),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        d["data"] = json.loads(d["data"])
        return d


def create_project(name: str, data: dict) -> int:
    """Insert a project (with `data` stored as JSON) and return its new id."""
    with _conn() as c:
        cur = c.execute(
            "INSERT INTO projects(name, data) VALUES(?,?)",
            (name, json.dumps(data)),
        )
        return cur.lastrowid


def update_project(
    pid: int,
    name: str | None = None,
    data: dict | None = None,
    snapshot: bool = False,
    note: str | None = None,
) -> bool:
    """Update name and/or data of project `pid`.

    If `snapshot=True`, the prior state is saved to history first (tagged with
    `note`) and the history is pruned to `SNAPSHOT_RETENTION` entries.

    Returns True when a project row was updated. Returns False when the
    project does not exist, or when neither `name` nor `data` was given; in
    the latter case a snapshot requested via `snapshot=True` is still
    recorded.
    """
    with _conn() as c:
        # Snapshot of the CURRENT (pre-update) state — useful for auto-save
        # checkpoints.
        if snapshot:
            old = c.execute(
                "SELECT data FROM projects WHERE id=?", (pid,)
            ).fetchone()
            if old is None:
                return False
            c.execute(
                "INSERT INTO snapshots(project_id, data, note) VALUES(?,?,?)",
                (pid, old["data"], note),
            )
            _purge_old_snapshots(c, pid)

        if name is None and data is None:
            return False

        fields, values = [], []
        if name is not None:
            fields.append("name=?")
            values.append(name)
        if data is not None:
            fields.append("data=?")
            values.append(json.dumps(data))
        fields.append("updated_at=CURRENT_TIMESTAMP")
        values.append(pid)
        # Only fixed column fragments are interpolated; every value is bound.
        r = c.execute(
            f"UPDATE projects SET {', '.join(fields)} WHERE id=?", values
        )
        return r.rowcount > 0


def delete_project(pid: int) -> bool:
    """Delete project `pid`; return True if a row was removed.

    The project's snapshots are removed by the schema's ON DELETE CASCADE.
    """
    with _conn() as c:
        r = c.execute("DELETE FROM projects WHERE id=?", (pid,))
        return r.rowcount > 0


# ---------------------------------------------------------------------------
# Snapshots
# ---------------------------------------------------------------------------

def list_snapshots(pid: int) -> list[dict]:
    """Return id, note and created_at of project `pid`'s snapshots, newest first.

    `id DESC` breaks ties between snapshots created within the same second
    (CURRENT_TIMESTAMP has one-second resolution), so the order is
    deterministic and matches the order used for pruning.
    """
    with _conn() as c:
        rows = c.execute(
            "SELECT id, note, created_at FROM snapshots "
            "WHERE project_id=? ORDER BY created_at DESC, id DESC",
            (pid,),
        ).fetchall()
        return [dict(r) for r in rows]


def get_snapshot(sid: int) -> dict | None:
    """Return snapshot `sid` with its `data` JSON-decoded, or None if absent."""
    with _conn() as c:
        row = c.execute(
            "SELECT id, project_id, data, note, created_at "
            "FROM snapshots WHERE id=?",
            (sid,),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        d["data"] = json.loads(d["data"])
        return d


def restore_snapshot(sid: int) -> bool:
    """Copy a snapshot's data back into its parent project (preserving history).

    The project's state before the restore is itself snapshotted with the note
    "auto: before restore". Returns False if the snapshot does not exist or
    the project update did not succeed.
    """
    snap = get_snapshot(sid)
    if not snap:
        return False
    return update_project(
        snap["project_id"],
        data=snap["data"],
        snapshot=True,
        note="auto: before restore",
    )


def _purge_old_snapshots(c: sqlite3.Connection, pid: int) -> None:
    """Delete all but the newest `SNAPSHOT_RETENTION` snapshots of project `pid`.

    Runs on the caller's connection `c`, so it takes part in the caller's
    transaction. Rows are ranked by `created_at` and then by `id`: timestamps
    only have one-second resolution, and without the id tie-breaker snapshots
    saved within the same second could be ranked oldest-first, causing the
    newest ones to be purged instead of the oldest.
    """
    c.execute(
        """
        DELETE FROM snapshots WHERE id IN (
            SELECT id FROM snapshots WHERE project_id=?
            ORDER BY created_at DESC, id DESC LIMIT -1 OFFSET ?
        )
        """,
        (pid, SNAPSHOT_RETENTION),
    )


# ---------------------------------------------------------------------------
# Presets
# ---------------------------------------------------------------------------

def list_presets() -> list[dict]:
    """Return every preset, ordered by name, with `params` JSON-decoded."""
    with _conn() as c:
        rows = c.execute(
            "SELECT id, name, params, created_at FROM presets ORDER BY name"
        ).fetchall()
        return [{**dict(r), "params": json.loads(r["params"])} for r in rows]


def get_preset(pid: int) -> dict | None:
    """Return preset `pid` with `params` JSON-decoded, or None if absent."""
    with _conn() as c:
        row = c.execute(
            "SELECT id, name, params, created_at FROM presets WHERE id=?",
            (pid,),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        d["params"] = json.loads(d["params"])
        return d


def create_preset(name: str, params: dict) -> int | None:
    """Insert a preset and return its id, or None if `name` is already taken."""
    with _conn() as c:
        try:
            cur = c.execute(
                "INSERT INTO presets(name, params) VALUES(?,?)",
                (name, json.dumps(params)),
            )
            return cur.lastrowid
        except sqlite3.IntegrityError:
            return None  # name UNIQUE collision


def update_preset(
    pid: int, name: str | None = None, params: dict | None = None
) -> bool:
    """Update name and/or params of preset `pid`.

    Returns False when nothing was supplied to change, when the statement
    raises an integrity error (e.g. the new name collides with another
    preset), or when no preset has that id.
    """
    with _conn() as c:
        sets, vals = [], []
        if name is not None:
            sets.append("name=?")
            vals.append(name)
        if params is not None:
            sets.append("params=?")
            vals.append(json.dumps(params))
        if not sets:
            return False
        vals.append(pid)
        try:
            r = c.execute(
                f"UPDATE presets SET {', '.join(sets)} WHERE id=?", vals
            )
        except sqlite3.IntegrityError:
            return False
        return r.rowcount > 0


def delete_preset(pid: int) -> bool:
    """Delete preset `pid`; return True if a row was removed."""
    with _conn() as c:
        r = c.execute("DELETE FROM presets WHERE id=?", (pid,))
        return r.rowcount > 0