import asyncio
import contextlib
from datetime import datetime, timezone
from logging import getLogger

import sentry_sdk
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from src import models
from src.config import settings
from src.dependencies import tracked_db
from src.schemas import DreamType
from src.utils.work_unit import construct_work_unit_key, parse_work_unit_key

logger = getLogger(__name__)


# Module-wide reference to the active scheduler. It stays None until
# set_dream_scheduler() is called and is read back through get_dream_scheduler().
_dream_scheduler: "DreamScheduler | None" = None


def set_dream_scheduler(dream_scheduler: "DreamScheduler") -> None:
    """Register ``dream_scheduler`` as the module-wide scheduler.

    The instance is stored in the module-level ``_dream_scheduler`` slot,
    replacing whatever was stored there before, and is what
    ``get_dream_scheduler()`` returns afterwards.

    Args:
        dream_scheduler: The scheduler instance to publish.
    """
    # Rebind the module global rather than creating a function-local name.
    global _dream_scheduler
    _dream_scheduler = dream_scheduler


def get_dream_scheduler() -> "DreamScheduler | None":
    """Return the module-wide scheduler, or ``None`` if none has been set.

    Returns:
        The current value of the module-level ``_dream_scheduler`` slot.
    """
    return _dream_scheduler


class DreamScheduler:
    """Singleton registry of delayed dream tasks, keyed by work unit key.

    Each scheduled dream is an ``asyncio`` task that sleeps for the requested
    delay and then calls :meth:`execute_dream`. The tasks are tracked only in
    the in-memory ``pending_dreams`` dict of this instance.
    """

    # The single shared instance, created lazily by __new__.
    _instance: "DreamScheduler | None" = None
    # Guards __init__ so that constructing the singleton again keeps its state.
    _initialized: bool = False

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Create the task registry the first time the singleton is built."""
        # __init__ runs on every DreamScheduler() call; only initialize once so
        # already-tracked tasks are not dropped.
        if not DreamScheduler._initialized:
            self.pending_dreams: dict[str, asyncio.Task[None]] = {}
            DreamScheduler._initialized = True

    @classmethod
    def reset_singleton(cls) -> None:
        """Reset the singleton instance. Only use this in tests."""
        cls._instance = None
        cls._initialized = False

    def _discard_finished(
        self, work_unit_key: str, task: "asyncio.Task[None]"
    ) -> None:
        """Forget ``work_unit_key`` only if it still maps to ``task``.

        A finished task must not evict a newer task that has since been
        registered under the same key.
        """
        if self.pending_dreams.get(work_unit_key) is task:
            del self.pending_dreams[work_unit_key]

    async def schedule_dream(
        self,
        work_unit_key: str,
        workspace_name: str,
        delay_minutes: int,
        dream_type: DreamType,
        *,
        observer: str,
        observed: str,
    ) -> None:
        """Schedule a dream to run after a delay, replacing any pending one.

        Does nothing when ``settings.DREAM.ENABLED`` is false.

        Args:
            work_unit_key: Key under which the task is tracked in
                ``pending_dreams``; a pending task with the same key is cancelled.
            workspace_name: Workspace passed through to :meth:`execute_dream`.
            delay_minutes: Minutes to wait before executing the dream.
            dream_type: Dream type passed through to :meth:`execute_dream`.
            observer: Observer peer name.
            observed: Observed peer name.
        """
        if not settings.DREAM.ENABLED:
            return

        # cancel_dream() suspends while the old task unwinds, so another caller
        # may register a task for the same key in the meantime. Keep cancelling
        # until nothing is left: the final call returns False without
        # suspending, so the registration below cannot interleave with another
        # coroutine on this event loop.
        while await self.cancel_dream(work_unit_key):
            pass

        task = asyncio.create_task(
            self._delayed_dream(
                work_unit_key,
                workspace_name,
                delay_minutes,
                dream_type,
                observer=observer,
                observed=observed,
            )
        )
        self.pending_dreams[work_unit_key] = task
        # Identity-checked cleanup: only this task may remove its own entry.
        task.add_done_callback(
            lambda finished: self._discard_finished(work_unit_key, finished)
        )

    async def cancel_dream(self, work_unit_key: str) -> bool:
        """Cancel a pending dream. Returns True if a dream was cancelled."""
        task = self.pending_dreams.pop(work_unit_key, None)
        if task is None:
            return False

        task.cancel()
        # Wait for the task to finish unwinding before reporting success.
        # NOTE(review): suppress() also swallows a CancelledError raised because
        # the *caller* was cancelled while waiting here - confirm that is acceptable.
        with contextlib.suppress(asyncio.CancelledError):
            await task
        return True

    async def cancel_dreams_for_observed(
        self, workspace_name: str, observed: str
    ) -> set[str]:
        """
        Cancel all pending dreams where the observed peer matches.

        This handles both self-observation (observer=observed) and peer-to-peer
        observation (observer!=observed) dreams.

        Args:
            workspace_name: The workspace to match
            observed: The observed peer name to match

        Returns:
            Set of work_unit_keys that were cancelled
        """
        # Snapshot the matching keys first: cancelling mutates pending_dreams.
        keys_to_cancel: list[str] = []
        for work_unit_key in list(self.pending_dreams):
            parsed = parse_work_unit_key(work_unit_key)
            if parsed.workspace_name == workspace_name and parsed.observed == observed:
                keys_to_cancel.append(work_unit_key)

        cancelled: set[str] = set()
        for key in keys_to_cancel:
            # A key may already be gone if its task finished while we awaited
            # an earlier cancellation; only report keys we actually cancelled.
            if await self.cancel_dream(key):
                cancelled.add(key)

        return cancelled

    async def _delayed_dream(
        self,
        work_unit_key: str,
        workspace_name: str,
        delay_minutes: int,
        dream_type: DreamType,
        *,
        observer: str,
        observed: str,
    ) -> None:
        """Sleep for ``delay_minutes`` and then execute the dream.

        This is the body of the task created by :meth:`schedule_dream`.
        Cancellation is logged at debug level and treated as a normal outcome;
        any other exception is logged with its traceback and, when Sentry is
        enabled, reported there. Nothing propagates out of the task.
        """
        try:
            await asyncio.sleep(delay_minutes * 60)

            await self.execute_dream(
                workspace_name,
                dream_type,
                observer=observer,
                observed=observed,
            )
            logger.info("Executed dream for %s", work_unit_key)

        except asyncio.CancelledError:
            logger.debug("Dream task cancelled for %s", work_unit_key)
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error in delayed dream for %s: %s", work_unit_key, e)
            if settings.SENTRY.ENABLED:
                sentry_sdk.capture_exception(e)

    async def execute_dream(
        self,
        workspace_name: str,
        dream_type: DreamType,
        *,
        observer: str,
        observed: str,
    ) -> None:
        """Enqueue a dream for an observer/observed pair, if still applicable.

        Looks up the session of the most recently created document for the
        pair and the pair's current document count, resolves the effective
        configuration for that session and workspace, and enqueues the dream.

        The dream is skipped (with a log message) when the pair has no
        documents or when the resolved configuration has dreams disabled.

        Args:
            workspace_name: Workspace the documents belong to.
            dream_type: Type of dream to enqueue.
            observer: Observer peer name.
            observed: Observed peer name.
        """
        # Import here to avoid circular dependency
        from src import crud
        from src.deriver.enqueue import enqueue_dream
        from src.utils.config_helpers import get_configuration

        # Find the most recent session and get current document count
        async with tracked_db("dream_session_lookup") as db:
            stmt = (
                select(models.Document.session_name)
                .where(
                    models.Document.workspace_name == workspace_name,
                    models.Document.observer == observer,
                    models.Document.observed == observed,
                )
                .order_by(models.Document.created_at.desc())
                .limit(1)
            )
            session_name = await db.scalar(stmt)

            if not session_name:
                logger.warning(
                    f"No documents found for {workspace_name}/{observer}/{observed}, skipping dream"
                )
                return

            # Get current document count at execution time (not stale from scheduling)
            count_stmt = select(func.count(models.Document.id)).where(
                models.Document.workspace_name == workspace_name,
                models.Document.observer == observer,
                models.Document.observed == observed,
            )
            current_document_count = int(await db.scalar(count_stmt) or 0)

            session = await crud.get_session(
                db, workspace_name=workspace_name, session_name=session_name
            )
            workspace = await crud.get_workspace(db, workspace_name=workspace_name)

            configuration = get_configuration(None, session, workspace)

            if not configuration.dream.enabled:
                logger.info(
                    f"Dreams disabled for {workspace_name}/{session_name}, skipping dream"
                )
                return

        # Enqueue after the DB context has been exited.
        await enqueue_dream(
            workspace_name,
            observer=observer,
            observed=observed,
            dream_type=dream_type,
            document_count=current_document_count,
            session_name=session_name,
        )

    async def shutdown(self) -> None:
        """Cancel all pending dreams during shutdown and wait for them to stop."""
        if not self.pending_dreams:
            return

        # Snapshot the tasks: their done-callbacks remove entries from
        # pending_dreams while we are waiting on them.
        tasks = list(self.pending_dreams.values())
        logger.info("Cancelling %d pending dreams...", len(tasks))
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        self.pending_dreams.clear()


def _hours_since_last_dream(
    last_dream_at: str, now: "datetime | None" = None
) -> float:
    """Return the hours elapsed between an ISO-8601 timestamp and ``now``.

    Args:
        last_dream_at: ISO-8601 timestamp string.
        now: Reference time; defaults to the current UTC time.

    Returns:
        Elapsed time in hours (negative if the timestamp is in the future).

    Raises:
        ValueError: If the string is not a valid ISO-8601 timestamp.
        TypeError: If ``last_dream_at`` is not a string.
    """
    last_dream_time = datetime.fromisoformat(last_dream_at)
    if last_dream_time.tzinfo is None:
        # Subtracting a naive datetime from an aware one raises TypeError.
        # NOTE(review): naive timestamps are assumed to be UTC - confirm
        # against whatever writes last_dream_at.
        last_dream_time = last_dream_time.replace(tzinfo=timezone.utc)
    reference = now if now is not None else datetime.now(timezone.utc)
    return (reference - last_dream_time).total_seconds() / 3600


async def check_and_schedule_dream(
    db: AsyncSession,
    collection: models.Collection,
) -> bool:
    """
    Check if a collection has reached the document threshold and schedule a timer-based dream.

    A timer is started for every enabled dream type only if:
    1. Dreams are enabled (``settings.DREAM.ENABLED``)
    2. The number of documents added since the last dream has reached
       ``settings.DREAM.DOCUMENT_THRESHOLD``
    3. At least ``settings.DREAM.MIN_HOURS_BETWEEN_DREAMS`` hours have passed
       since the last dream (an unparseable timestamp is logged and ignored)
    4. A scheduler has been registered via ``set_dream_scheduler``

    A timer that is already pending for the same work unit key is not kept:
    ``DreamScheduler.schedule_dream`` cancels it and starts a new one.

    Args:
        db: Database session
        collection: Collection model to check

    Returns:
        True if at least one dream timer was scheduled, False otherwise
    """
    if not settings.DREAM.ENABLED:
        return False

    # Get dream metadata from internal_metadata
    dream_metadata = collection.internal_metadata.get("dream", {})
    last_dream_document_count = dream_metadata.get("last_dream_document_count", 0)
    last_dream_at = dream_metadata.get("last_dream_at")

    # Count current documents in the collection
    count_stmt = select(func.count(models.Document.id)).where(
        models.Document.workspace_name == collection.workspace_name,
        models.Document.observer == collection.observer,
        models.Document.observed == collection.observed,
    )
    current_document_count = int(await db.scalar(count_stmt) or 0)

    # Calculate documents added since last dream
    documents_since_last_dream = current_document_count - last_dream_document_count

    logger.debug(
        "Dream check",
        extra={
            "workspace_name": collection.workspace_name,
            "observer": collection.observer,
            "observed": collection.observed,
            "current_document_count": current_document_count,
            "last_dream_document_count": last_dream_document_count,
            "documents_since_last_dream": documents_since_last_dream,
            "document_threshold": settings.DREAM.DOCUMENT_THRESHOLD,
        },
    )

    # Only schedule timer if document threshold is reached
    if documents_since_last_dream < settings.DREAM.DOCUMENT_THRESHOLD:
        return False

    # Check if we're within minimum hours between dreams
    if last_dream_at:
        try:
            hours_since_last_dream = _hours_since_last_dream(last_dream_at)
        except (ValueError, TypeError) as e:
            # Best effort: an unreadable timestamp does not block scheduling.
            logger.warning(
                f"Invalid last_dream_at timestamp: {last_dream_at}, error: {e}"
            )
        else:
            if hours_since_last_dream < settings.DREAM.MIN_HOURS_BETWEEN_DREAMS:
                logger.info(
                    f"Skipping dream for {collection.observer}/{collection.observed}: only {hours_since_last_dream:.1f} hours "
                    + f"since last dream (minimum: {settings.DREAM.MIN_HOURS_BETWEEN_DREAMS})"
                )
                return False

    dream_scheduler = get_dream_scheduler()
    if dream_scheduler is None:
        return False

    scheduled = False
    for dream_type in settings.DREAM.ENABLED_TYPES:
        # Include dream_type in key so each dream type can be tracked independently
        dream_work_unit_key = construct_work_unit_key(
            collection.workspace_name,
            {
                "task_type": "dream",
                "observer": collection.observer,
                "observed": collection.observed,
                "dream_type": dream_type,
            },
        )
        await dream_scheduler.schedule_dream(
            dream_work_unit_key,
            collection.workspace_name,
            settings.DREAM.IDLE_TIMEOUT_MINUTES,
            dream_type=DreamType(dream_type),
            observer=collection.observer,
            observed=collection.observed,
        )
        scheduled = True
        logger.debug(
            "Scheduled dream",
            extra={
                "workspace_name": collection.workspace_name,
                "observer": collection.observer,
                "observed": collection.observed,
                "documents_since_last_dream": documents_since_last_dream,
                "document_threshold": settings.DREAM.DOCUMENT_THRESHOLD,
                "dream_type": dream_type,
            },
        )

    # With no enabled dream types nothing was scheduled, so report False.
    return scheduled
