from dataclasses import dataclass
from logging import getLogger
from typing import Any

from cashews import NOT_NONE
from sqlalchemy import Select, delete, exists, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import make_transient_to_detached

from src import models, schemas
from src.cache.client import (
    cache,
    get_cache_namespace,
    safe_cache_delete,
    safe_cache_set,
)
from src.config import settings
from src.exceptions import ConflictException, ResourceNotFoundException
from src.utils.filter import apply_filter
from src.utils.types import GetOrCreateResult
from src.vector_store import get_external_vector_store

logger = getLogger(__name__)


@dataclass
class WorkspaceDeletionResult:
    """Outcome of deleting a workspace: a pre-deletion snapshot plus row counts."""

    # Snapshot of the workspace captured before its row was removed
    workspace: schemas.Workspace
    # How many peers belonged to the workspace when deletion started
    peers_deleted: int
    # How many sessions belonged to the workspace when deletion started
    sessions_deleted: int
    # How many messages belonged to the workspace when deletion started
    messages_deleted: int
    # How many conclusions (document rows) belonged to the workspace
    conclusions_deleted: int


# Key template shared by the _fetch_workspace cache/lock decorators and
# workspace_cache_key(); {workspace_name} is filled in per call.
WORKSPACE_CACHE_KEY_TEMPLATE = "v2:workspace:{workspace_name}"
# Namespace-scoped prefix under which the _fetch_workspace locks are stored.
WORKSPACE_LOCK_PREFIX = ":".join((get_cache_namespace(), "lock", "v2"))


def workspace_cache_key(workspace_name: str) -> str:
    """
    Build the fully namespaced cache key for a workspace.

    The key is the cache namespace, a colon, then
    ``WORKSPACE_CACHE_KEY_TEMPLATE`` filled with ``workspace_name``. This is
    intended to match the key the ``_fetch_workspace`` cache decorator writes,
    so it can be used to set or invalidate that entry directly.
    """
    namespace = get_cache_namespace()
    suffix = WORKSPACE_CACHE_KEY_TEMPLATE.format(workspace_name=workspace_name)
    return f"{namespace}:{suffix}"


@cache(
    key=WORKSPACE_CACHE_KEY_TEMPLATE,
    ttl=f"{settings.CACHE.DEFAULT_TTL_SECONDS}s",
    prefix=get_cache_namespace(),
    condition=NOT_NONE,
)
@cache.locked(
    key=WORKSPACE_CACHE_KEY_TEMPLATE,
    ttl=f"{settings.CACHE.DEFAULT_LOCK_TTL_SECONDS}s",
    prefix=WORKSPACE_LOCK_PREFIX,
)
async def _fetch_workspace(
    db: AsyncSession, workspace_name: str
) -> dict[str, Any] | None:
    """
    Load a workspace row and flatten it into a plain dict.

    A dict rather than the ORM instance is returned so the cached value carries
    no session state. A missing workspace yields ``None``; with
    ``condition=NOT_NONE`` that result is not written to the cache.
    """
    workspace_row = await db.scalar(
        select(models.Workspace).where(models.Workspace.name == workspace_name)
    )
    if workspace_row is None:
        return None
    # Same keys, in the same order, as models.Workspace(**data) expects downstream
    cached_fields = (
        "id",
        "name",
        "h_metadata",
        "internal_metadata",
        "configuration",
        "created_at",
    )
    return {field: getattr(workspace_row, field) for field in cached_fields}


async def get_or_create_workspace(
    db: AsyncSession,
    workspace: schemas.WorkspaceCreate,
    *,
    _retry: bool = False,
) -> GetOrCreateResult[models.Workspace]:
    """
    Get an existing workspace or create a new one if it doesn't exist.

    A new workspace is flushed inside a SAVEPOINT but not committed; the caller
    owns the commit. The result for a newly created workspace carries an
    ``on_commit`` callback that writes the workspace into the cache.

    Args:
        db: Database session
        workspace: Workspace creation schema
        _retry: Internal flag marking the single retry after an
            ``IntegrityError`` on insert; not meant for external callers

    Returns:
        GetOrCreateResult containing the workspace and whether it was created

    Raises:
        ValueError: If ``workspace.name`` is empty
        ConflictException: If we fail to get or create the workspace
    """
    if not workspace.name:
        raise ValueError("Workspace name must be provided")

    # Check if workspace already exists (read through the cache)
    data = await _fetch_workspace(db, workspace.name)
    if data is not None:
        logger.debug("Found existing workspace: %s", workspace.name)
        # Reconstruct the ORM object from the cached dict and attach it to the
        # session; merge(load=False) does not re-select the row.
        obj = models.Workspace(**data)
        make_transient_to_detached(obj)
        existing_workspace = await db.merge(obj, load=False)
        return GetOrCreateResult(existing_workspace, created=False)

    # Workspace doesn't exist, create a new one
    honcho_workspace = models.Workspace(
        name=workspace.name,
        h_metadata=workspace.metadata,
        configuration=workspace.configuration.model_dump(exclude_none=True),
    )
    try:
        # Only the savepoint flush can raise IntegrityError, so keep the try
        # this narrow. Presumably this is a concurrent creator inserting the
        # same name first — TODO confirm the unique constraint.
        async with db.begin_nested():
            db.add(honcho_workspace)
    except IntegrityError:
        if _retry:
            raise ConflictException(
                f"Unable to create or get workspace: {workspace.name}"
            ) from None
        # Retry once; the row written by the other transaction should now be found.
        return await get_or_create_workspace(db, workspace, _retry=True)

    logger.debug("Workspace created successfully: %s", workspace.name)

    # Capture cache data eagerly so the closure holds a plain dict, not the ORM object
    cache_key = workspace_cache_key(workspace.name)
    cache_data = {
        "id": honcho_workspace.id,
        "name": honcho_workspace.name,
        "h_metadata": honcho_workspace.h_metadata,
        "internal_metadata": honcho_workspace.internal_metadata,
        "configuration": honcho_workspace.configuration,
        "created_at": honcho_workspace.created_at,
    }

    async def _warm_workspace_cache():
        await safe_cache_set(
            cache_key,
            cache_data,
            expire=settings.CACHE.DEFAULT_TTL_SECONDS,
        )

    return GetOrCreateResult(
        honcho_workspace, created=True, on_commit=_warm_workspace_cache
    )


async def get_all_workspaces(
    filters: dict[str, Any] | None = None,
) -> Select[tuple[models.Workspace]]:
    """
    Build a query selecting all workspaces, oldest first.

    The statement is returned unexecuted; callers execute it themselves.

    Args:
        filters: Optional dictionary passed to ``apply_filter`` to narrow the
            workspaces (e.g. by metadata)

    Returns:
        A ``Select`` over ``models.Workspace`` ordered by ``created_at``
    """
    # Annotate once up front instead of re-declaring the variable's type later
    stmt: Select[tuple[models.Workspace]] = select(models.Workspace)
    stmt = apply_filter(stmt, models.Workspace, filters)
    return stmt.order_by(models.Workspace.created_at)


async def get_workspace(
    db: AsyncSession,
    workspace_name: str,
) -> models.Workspace:
    """
    Get an existing workspace; never creates one.

    The row is read through the ``_fetch_workspace`` cache, rebuilt as an ORM
    instance, and attached to ``db`` with ``merge(load=False)``.

    Args:
        db: Database session
        workspace_name: Name of the workspace

    Returns:
        The existing workspace, attached to ``db``

    Raises:
        ResourceNotFoundException: If the workspace does not exist
    """
    data = await _fetch_workspace(db, workspace_name)
    if data is None:
        raise ResourceNotFoundException(f"Workspace {workspace_name} not found")

    # The cached payload is a plain dict: give the rebuilt instance an identity
    # (transient -> detached) so merge(load=False) attaches it without a SELECT.
    workspace = models.Workspace(**data)
    make_transient_to_detached(workspace)
    return await db.merge(workspace, load=False)


async def update_workspace(
    db: AsyncSession, workspace_name: str, workspace: schemas.WorkspaceUpdate
) -> models.Workspace:
    """
    Apply metadata and configuration updates to a workspace, creating it first if absent.

    When ``workspace.metadata`` is given, it replaces the stored metadata.
    When ``workspace.configuration`` is given, its non-``None`` fields are
    merged over the stored configuration, so keys it omits are kept.

    The session is committed in every case. The cached workspace entry is
    invalidated only when something actually changed.

    Args:
        db: Database session
        workspace_name: Name of the workspace
        workspace: Workspace update schema

    Returns:
        The updated workspace

    Raises:
        ConflictException: If concurrent creation prevents fetching or creating
            the workspace
    """
    create_request = schemas.WorkspaceCreate(
        name=workspace_name,
        # Fall back to an empty dict when no metadata was supplied
        metadata=workspace.metadata or {},
    )
    ws_result = await get_or_create_workspace(db, create_request)
    honcho_workspace: models.Workspace = ws_result.resource

    metadata_changed = (
        workspace.metadata is not None
        and honcho_workspace.h_metadata != workspace.metadata
    )
    if metadata_changed:
        honcho_workspace.h_metadata = workspace.metadata

    config_changed = False
    if workspace.configuration is not None:
        # Merge over a copy so existing keys survive unless explicitly overridden
        merged_config = dict(honcho_workspace.configuration or {})
        merged_config.update(workspace.configuration.model_dump(exclude_none=True))
        config_changed = honcho_workspace.configuration != merged_config
        if config_changed:
            honcho_workspace.configuration = merged_config

    await db.commit()
    await ws_result.post_commit()

    if not (metadata_changed or config_changed):
        logger.debug("Workspace %s unchanged, skipping update", workspace_name)
        return honcho_workspace

    # Something was written, so the cached copy is stale
    await safe_cache_delete(workspace_cache_key(workspace_name))

    logger.debug("Workspace with id %s updated successfully", honcho_workspace.id)
    return honcho_workspace


async def check_no_active_sessions(db: AsyncSession, workspace_name: str) -> None:
    """
    Ensure a workspace has no session flagged as active.

    Args:
        db: Database session
        workspace_name: Name of the workspace

    Raises:
        ConflictException: If at least one active session exists in the workspace
    """
    active_session_query = select(
        exists().where(
            models.Session.workspace_name == workspace_name,
            # SQL equality comparison, not a Python identity check
            models.Session.is_active == True,  # noqa: E712
        )
    )
    if not await db.scalar(active_session_query):
        return
    raise ConflictException(
        f"Cannot delete workspace '{workspace_name}': active session(s) remain. Delete all sessions first."
    )


async def _count_workspace_rows(
    db: AsyncSession,
    counted_column: Any,
    workspace_column: Any,
    workspace_name: str,
) -> int:
    """Return ``COUNT(counted_column)`` over rows whose ``workspace_column`` matches."""
    count = await db.scalar(
        select(func.count(counted_column)).where(workspace_column == workspace_name)
    )
    return int(count or 0)


async def _delete_workspace_vector_namespaces(
    workspace_name: str, collection_keys: list[tuple[Any, Any]]
) -> None:
    """
    Delete a workspace's external vector-store namespaces, logging per-namespace failures.

    Does nothing when ``get_external_vector_store()`` returns a falsy value.
    A failed ``delete_namespace`` call is logged and does not stop the others.
    """
    external_vector_store = get_external_vector_store()
    if not external_vector_store:
        return

    # Message embeddings use one namespace per workspace
    message_namespace = external_vector_store.get_vector_namespace(
        "message", workspace_name
    )
    try:
        await external_vector_store.delete_namespace(message_namespace)
        logger.debug(
            "Deleted message embeddings namespace %s for workspace %s",
            message_namespace,
            workspace_name,
        )
    except Exception as e:
        logger.warning(
            "Failed to delete message embeddings namespace %s: %s",
            message_namespace,
            e,
        )

    # Document embeddings use one namespace per (observer, observed) collection
    for observer, observed in collection_keys:
        doc_namespace = external_vector_store.get_vector_namespace(
            "document",
            workspace_name,
            observer,
            observed,
        )
        try:
            await external_vector_store.delete_namespace(doc_namespace)
            logger.debug(
                "Deleted document namespace %s for collection %s/%s",
                doc_namespace,
                observer,
                observed,
            )
        except Exception as e:
            logger.warning(
                "Failed to delete document namespace %s: %s",
                doc_namespace,
                e,
            )


async def delete_workspace(
    db: AsyncSession, workspace_name: str
) -> WorkspaceDeletionResult:
    """
    Delete a workspace and every row stored under it.

    All database rows are deleted and committed in one transaction. After the
    commit, the workspace's vector-store namespaces and cache entries are
    cleaned up on a best-effort basis. Failures there are logged, not raised,
    because the deletion itself has already succeeded.

    Args:
        db: Database session
        workspace_name: Name of the workspace

    Returns:
        WorkspaceDeletionResult containing a snapshot of the deleted workspace
        and cascade counts for deleted resources

    Raises:
        ResourceNotFoundException: If the workspace does not exist
        Exception: Any error during the database deletion is logged, the
            session is rolled back, and the error is re-raised
    """
    logger.warning("Deleting workspace %s", workspace_name)
    result = await db.execute(
        select(models.Workspace).where(models.Workspace.name == workspace_name)
    )
    honcho_workspace = result.scalar_one_or_none()

    if honcho_workspace is None:
        logger.warning("Workspace %s not found", workspace_name)
        raise ResourceNotFoundException(f"Workspace {workspace_name} not found")

    # NOTE: No active session check here — that gate lives in the router.
    # This crud method is called by the background worker, where a session
    # could have been created after the user's request was accepted (202).
    # The deletion should proceed and cascade-delete any new sessions.

    # Create a snapshot of the workspace data before deletion
    workspace_snapshot = schemas.Workspace(
        name=honcho_workspace.name,
        h_metadata=honcho_workspace.h_metadata,
        configuration=honcho_workspace.configuration,
        created_at=honcho_workspace.created_at,
    )

    # Count resources before deletion for telemetry (sequentially: an
    # AsyncSession must not run concurrent statements)
    peers_count = await _count_workspace_rows(
        db, models.Peer.id, models.Peer.workspace_name, workspace_name
    )
    sessions_count = await _count_workspace_rows(
        db, models.Session.id, models.Session.workspace_name, workspace_name
    )
    messages_count = await _count_workspace_rows(
        db, models.Message.id, models.Message.workspace_name, workspace_name
    )
    conclusions_count = await _count_workspace_rows(
        db, models.Document.id, models.Document.workspace_name, workspace_name
    )

    # Order is important here:
    # active queue sessions (matched by parsing work_unit_key), then queue items,
    # then embeddings, documents, collections, messages,
    # then webhook endpoints, session_peers, sessions, peers, and the workspace.
    try:
        # Work unit keys have format: {task_type}:{workspace_name}:{...}
        # split_part is 1-indexed, so part 2 is the workspace name.
        await db.execute(
            delete(models.ActiveQueueSession).where(
                func.split_part(models.ActiveQueueSession.work_unit_key, ":", 2)
                == workspace_name
            )
        )

        await db.execute(
            delete(models.QueueItem).where(
                models.QueueItem.workspace_name == workspace_name
            )
        )

        # Also delete any queue items that reference messages in this workspace
        # (handles race condition where deriver creates new queue items)
        message_ids_subquery = select(models.Message.id).where(
            models.Message.workspace_name == workspace_name
        )
        await db.execute(
            delete(models.QueueItem).where(
                models.QueueItem.message_id.in_(message_ids_subquery)
            )
        )

        # Capture collection keys as plain values before the bulk delete and
        # commit. The post-commit vector cleanup then never reads attributes
        # from ORM instances the session may have detached or expired.
        collections_result = await db.execute(
            select(models.Collection.observer, models.Collection.observed).where(
                models.Collection.workspace_name == workspace_name
            )
        )
        collection_keys = [
            (observer, observed) for observer, observed in collections_result
        ]

        await db.execute(
            delete(models.MessageEmbedding).where(
                models.MessageEmbedding.workspace_name == workspace_name
            )
        )
        await db.execute(
            delete(models.Document).where(
                models.Document.workspace_name == workspace_name
            )
        )
        await db.execute(
            delete(models.Collection).where(
                models.Collection.workspace_name == workspace_name
            )
        )
        await db.execute(
            delete(models.Message).where(
                models.Message.workspace_name == workspace_name
            )
        )

        await db.execute(
            delete(models.WebhookEndpoint).where(
                models.WebhookEndpoint.workspace_name == workspace_name
            )
        )
        await db.execute(
            delete(models.SessionPeer).where(
                models.SessionPeer.workspace_name == workspace_name
            )
        )
        await db.execute(
            delete(models.Session).where(
                models.Session.workspace_name == workspace_name
            )
        )
        await db.execute(
            delete(models.Peer).where(models.Peer.workspace_name == workspace_name)
        )
        await db.delete(honcho_workspace)
        await db.commit()
    except Exception:
        logger.exception(
            "Failed to delete workspace %s",
            workspace_name,
        )
        await db.rollback()
        raise

    # From here on the deletion is committed. Cleanup failures must not be
    # reported as a failed deletion, and rolling back would be a no-op.
    try:
        await _delete_workspace_vector_namespaces(workspace_name, collection_keys)
    except Exception:
        logger.exception(
            "Failed to clean up vector store for deleted workspace %s",
            workspace_name,
        )

    # The trailing wildcard also matches keys that merely extend this one
    # (including other workspaces whose names start with this name). That
    # over-invalidation costs at most extra cache misses.
    workspace_pattern = f"{workspace_cache_key(workspace_name)}*"
    try:
        await cache.delete_match(workspace_pattern)
    except Exception:
        logger.exception(
            "Failed to invalidate cache for deleted workspace %s",
            workspace_name,
        )

    logger.debug("Workspace %s deleted", workspace_name)

    return WorkspaceDeletionResult(
        workspace=workspace_snapshot,
        peers_deleted=peers_count,
        sessions_deleted=sessions_count,
        messages_deleted=messages_count,
        conclusions_deleted=conclusions_count,
    )
