#!/usr/bin/env python3
"""Toy Aubry-Andre plus BdG exploration for phi.lanz.es.

This script is intentionally lightweight: it provides a compact, reproducible
starting point for students who want to explore the Aubry-Andre localization
(metal-insulator) transition, a simple Bogoliubov-de Gennes (BdG) extension with
uniform on-site pairing, inverse participation ratios (IPR), and a drift-response
proxy (the central BdG gap versus a linear on-site drift). It is not a full
parafermion solver and should be used as a pedagogical bridge.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import eigh

# Golden ratio phi = (1 + sqrt(5)) / 2 and its reciprocal 1/phi (= phi - 1 ~ 0.618),
# used as the default incommensurate frequency of the Aubry-Andre potential.
GOLDEN_RATIO = 0.5 * (1.0 + np.sqrt(5.0))
GOLDEN_RATIO_INVERSE = 1.0 / GOLDEN_RATIO


@dataclass(frozen=True)
class AAParams:
    """Immutable bundle of model parameters shared by every builder in this script.

    The defaults describe an 89-site open chain whose cosine potential has
    frequency 1/golden-ratio.
    """

    N: int = 89  # number of lattice sites; also the single-particle matrix dimension
    t: float = 1.0  # nearest-neighbour hopping amplitude (enters the matrix as -t)
    lam: float = 1.5  # strength of the quasiperiodic cosine potential
    beta: float = GOLDEN_RATIO_INVERSE  # frequency of the cosine potential (1/golden ratio by default)
    phase: float = 0.0  # phase offset added inside the cosine, in radians
    mu: float = 0.0  # chemical potential, subtracted from every on-site energy
    delta: float = 0.22  # uniform on-site pairing amplitude, used by build_bdg_hamiltonian
    # Close the chain into a ring (only applied when N > 2).
    # NOTE: with irrational beta the cosine does not repeat after N sites, so a
    # ring carries a seam in the potential between site N-1 and site 0.
    periodic: bool = False
    bridge_drift: float = 0.0  # amplitude of a linear on-site ramp from -bridge_drift to +bridge_drift


def build_aa_hamiltonian(params: AAParams) -> np.ndarray:
    """Return the real N x N single-particle Aubry-Andre Hamiltonian.

    Diagonal entry j is
        lam * cos(2*pi*beta*j + phase) + ramp_j - mu,
    where ramp_j runs linearly from -bridge_drift (first site) to +bridge_drift
    (last site). Nearest neighbours are coupled by -t. When ``periodic`` is set
    and N > 2, the two corner elements also get -t to close the ring. For N <= 2
    the ring is left open so the single bond is not counted twice.
    """
    n_sites = params.N
    site_index = np.arange(n_sites)
    quasiperiodic = params.lam * np.cos(2.0 * np.pi * params.beta * site_index + params.phase)
    ramp = params.bridge_drift * np.linspace(-1.0, 1.0, n_sites)

    matrix = np.zeros((n_sites, n_sites), dtype=float)
    matrix[site_index, site_index] = quasiperiodic + ramp - params.mu

    # Vectorized nearest-neighbour bonds (j, j+1) and their Hermitian partners.
    left = site_index[:-1]
    right = left + 1
    matrix[left, right] = -params.t
    matrix[right, left] = -params.t

    if params.periodic and n_sites > 2:
        matrix[0, n_sites - 1] = matrix[n_sites - 1, 0] = -params.t

    return matrix


def build_bdg_hamiltonian(params: AAParams) -> np.ndarray:
    """Assemble the 2N x 2N Bogoliubov-de Gennes matrix.

    With h = build_aa_hamiltonian(params) and D = delta * identity, the layout is

        [[ h,    D   ],
         [ D^T,  -h^T]]

    Because h is real symmetric and D is proportional to the identity, the
    eigenvalues come in pairs +/- sqrt(eps_n**2 + delta**2), where eps_n are the
    eigenvalues of h.

    NOTE(review): for spinless fermions an on-site pairing term c_j c_j vanishes.
    This block is best read as a spin-singlet (s-wave) Nambu form. A Kitaev-type
    p-wave chain would need an antisymmetric nearest-neighbour D. Confirm the
    intended model.
    """
    single = build_aa_hamiltonian(params)
    n = params.N
    pairing = np.eye(n) * params.delta

    # Same dtype np.block would produce from these two operands.
    bdg = np.zeros((2 * n, 2 * n), dtype=np.result_type(single, pairing))
    bdg[:n, :n] = single
    bdg[:n, n:] = pairing
    bdg[n:, :n] = pairing.T
    bdg[n:, n:] = -single.T
    return bdg


def inverse_participation_ratio(vector: np.ndarray) -> float:
    """Return the inverse participation ratio sum(|v_i|**4) / (sum(|v_i|**2))**2.

    The vector need not be normalized, because the ratio divides by the squared
    total weight. A vector with zero total weight (including an empty one) gives
    0.0. The value is 1/len(v) for an evenly spread state and 1.0 for a state
    living on a single site.
    """
    density = np.square(np.abs(vector))
    total = density.sum()
    if total == 0:
        return 0.0
    return float(np.square(density).sum() / total**2)


def characterize_phase(params: AAParams) -> dict[str, float | str]:
    """Diagonalize the Aubry-Andre Hamiltonian and summarize its spectrum.

    Returns a dict with:
        "phase": heuristic regime label ("metallic", "critical" or "insulating").
            It is derived from the parameters alone, not from the computed IPR.
        "mean_ipr": inverse participation ratio averaged over all eigenstates.
        "min_eigenvalue" / "max_eigenvalue": the spectral extremes.

    The label compares |lam| with |t|. The Aubry-Andre transition sits at
    |lam| = 2|t|, and the +/-5% window 1.9|t| <= |lam| <= 2.1|t| is called
    "critical". Absolute values are needed because neither sign flip changes
    localization:
      - lam -> -lam is the same as phase -> phase + pi;
      - t -> -t is removed by the gauge c_j -> (-1)^j c_j (exact on open chains).
    Comparing signed values mislabelled, e.g., lam=1, t=-1 as "insulating".
    """
    evals, evecs = eigh(build_aa_hamiltonian(params))
    # eigh stores eigenvectors as columns, so iterate over rows of the transpose.
    ipr_values = np.array([inverse_participation_ratio(state) for state in evecs.T])
    mean_ipr = float(np.mean(ipr_values))

    strength = abs(params.lam)
    hopping = abs(params.t)
    if strength < 1.9 * hopping:
        phase = "metallic"
    elif strength <= 2.1 * hopping:
        phase = "critical"
    else:
        phase = "insulating"

    return {
        "phase": phase,
        "mean_ipr": mean_ipr,
        "min_eigenvalue": float(np.min(evals)),
        "max_eigenvalue": float(np.max(evals)),
    }


def central_bdg_gap(params: AAParams) -> float:
    """Return the spacing between the lowest non-negative and highest negative BdG level.

    Zero counts as non-negative. If the spectrum lies entirely on one side of
    zero (or is empty), the gap is reported as 0.0.
    """
    spectrum = np.sort(eigh(build_bdg_hamiltonian(params), eigvals_only=True))
    # First index whose eigenvalue is >= 0; everything before it is negative.
    split = int(np.searchsorted(spectrum, 0.0, side="left"))
    if split in (0, spectrum.size):
        return 0.0
    return float(spectrum[split] - spectrum[split - 1])


def sweep_phase_diagram(
    size: int = 89,
    *,
    lambda_max: float = 4.0,
    num_lambda: int = 81,
    num_phase: int = 81,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Map the mean IPR over a grid of quasiperiodic strength and phase offset.

    Args:
        size: number of lattice sites N used for every sample.
        lambda_max: upper end of the lam grid, which starts at 0.
        num_lambda: number of lam samples (inclusive of both ends).
        num_phase: number of phase-offset samples spanning [0, 2*pi] inclusive.
            Note that the first and last rows describe the same phase modulo 2*pi.

    Returns:
        (lambda_values, phase_offsets, heatmap), where heatmap[row, col] is the
        mean IPR for phase_offsets[row] and lambda_values[col]. The defaults
        reproduce the original fixed 81 x 81 grid over lam in [0, 4].
    """
    lambda_values = np.linspace(0.0, lambda_max, num_lambda)
    # Offsets of the cosine's phase (AAParams.phase), not of the frequency beta.
    phase_offsets = np.linspace(0.0, 2.0 * np.pi, num_phase)
    heatmap = np.zeros((len(phase_offsets), len(lambda_values)))

    for row, phase_offset in enumerate(phase_offsets):
        for col, lam in enumerate(lambda_values):
            params = AAParams(N=size, lam=float(lam), phase=float(phase_offset))
            heatmap[row, col] = characterize_phase(params)["mean_ipr"]

    return lambda_values, phase_offsets, heatmap


def bridge_response_curve() -> tuple[np.ndarray, np.ndarray]:
    """Return the central BdG gap as a function of the linear drift amplitude.

    The drift takes 12 evenly spaced values in [0, 0.45], with lam = 1.0 and all
    other parameters at their AAParams defaults. Returns (drifts, gaps).
    """
    drift_grid = np.linspace(0.0, 0.45, 12)
    gaps = np.array(
        [central_bdg_gap(AAParams(lam=1.0, bridge_drift=float(value))) for value in drift_grid]
    )
    return drift_grid, gaps


def save_phase_diagram(output_dir: Path) -> None:
    """Render the mean-IPR sweep as a heatmap and save ``aa_phase_diagram.png``.

    Args:
        output_dir: existing directory that receives the PNG.
    """
    lambda_values, phase_offsets, heatmap = sweep_phase_diagram()

    def _cell_half_width(samples: np.ndarray) -> float:
        # Half the spacing of a uniform grid; fall back to 0.5 for a single sample.
        return 0.5 * abs(float(samples[1] - samples[0])) if samples.size > 1 else 0.5

    # imshow's ``extent`` sets the outer *edges* of the image. Pad by half a cell
    # on every side so each pixel is centred on the sample it represents,
    # instead of being shifted by half a cell.
    d_lam = _cell_half_width(lambda_values)
    d_phase = _cell_half_width(phase_offsets)
    extent = [
        lambda_values.min() - d_lam,
        lambda_values.max() + d_lam,
        phase_offsets.min() - d_phase,
        phase_offsets.max() + d_phase,
    ]

    fig, ax = plt.subplots(figsize=(8.6, 5.2))
    image = ax.imshow(
        heatmap,
        aspect="auto",
        origin="lower",
        extent=extent,
        cmap="magma",
    )
    ax.axvline(2.0, color="#f6d06e", linestyle="--", linewidth=1.2, label=r"$\lambda = 2t$")
    ax.set_xlabel(r"Quasiperiodic strength $\lambda / t$")
    ax.set_ylabel(r"Phase offset")
    ax.set_title("Aubry-Andre phase diagram via mean IPR")
    ax.legend(frameon=False)
    fig.colorbar(image, ax=ax, label="Mean inverse participation ratio")
    fig.tight_layout()
    fig.savefig(output_dir / "aa_phase_diagram.png", dpi=180)
    plt.close(fig)


def save_eigenstate_examples(output_dir: Path) -> None:
    """Plot |psi|^2 of one eigenstate for each of three lam values.

    The eigenstate is the one with index N // 2 in eigh's ascending-energy order.
    Writes ``eigenstate_<label>.png`` into ``output_dir`` for each regime:
    lam = 1.0, 2.0 and 3.0 (other parameters at AAParams defaults).
    """
    regimes = (
        ("metallic", AAParams(lam=1.0)),
        ("critical", AAParams(lam=2.0)),
        ("insulating", AAParams(lam=3.0)),
    )

    for regime, config in regimes:
        _, eigenvectors = eigh(build_aa_hamiltonian(config))
        density = np.abs(eigenvectors[:, config.N // 2]) ** 2

        fig, ax = plt.subplots(figsize=(8.4, 3.6))
        ax.plot(density, color="#c9a84c", linewidth=1.4)
        ax.fill_between(np.arange(len(density)), density, color="#c9a84c", alpha=0.18)
        ax.set_xlabel("Lattice site")
        ax.set_ylabel(r"$|\psi|^2$")
        ax.set_title(f"Representative eigenstate in the {regime} regime")
        fig.tight_layout()
        fig.savefig(output_dir / f"eigenstate_{regime}.png", dpi=180)
        plt.close(fig)


def save_bridge_validation(output_dir: Path) -> None:
    """Plot the central BdG gap against the injected drift.

    Saves ``bridge_hypothesis_validation.png`` into ``output_dir``.
    """
    drift_values, gap_values = bridge_response_curve()

    figure, axis = plt.subplots(figsize=(7.8, 4.4))
    axis.plot(drift_values, gap_values, color="#3d8bcd", linewidth=1.6, marker="o", markersize=4)
    axis.set(
        xlabel="Injected drift parameter",
        ylabel="Central BdG gap proxy",
        title="Bridge hypothesis validation proxy",
    )
    axis.grid(alpha=0.18)
    figure.tight_layout()
    figure.savefig(output_dir / "bridge_hypothesis_validation.png", dpi=180)
    plt.close(figure)


def main() -> None:
    """Render every figure into the current working directory and list the files written."""
    destination = Path.cwd()
    for writer in (save_phase_diagram, save_eigenstate_examples, save_bridge_validation):
        writer(destination)

    summary = (
        "Created:",
        "- aa_phase_diagram.png",
        "- eigenstate_metallic.png",
        "- eigenstate_critical.png",
        "- eigenstate_insulating.png",
        "- bridge_hypothesis_validation.png",
    )
    for line in summary:
        print(line)


if __name__ == "__main__":
    # Render the figures only when run as a script, not when imported.
    main()
