import asyncio
import signal
from asyncio import Task
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import Any, NamedTuple, cast

import sentry_sdk
from dotenv import load_dotenv
from nanoid import generate as generate_nanoid
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sqlalchemy import and_, delete, or_, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func

from src import models
from src.cache.client import close_cache, init_cache
from src.config import settings
from src.dependencies import tracked_db
from src.deriver.consumer import (
    process_item,
    process_representation_batch,
)
from src.dreamer.dream_scheduler import (
    DreamScheduler,
    get_dream_scheduler,
    set_dream_scheduler,
)
from src.models import QueueItem
from src.reconciler import (
    ReconcilerScheduler,
    get_reconciler_scheduler,
    set_reconciler_scheduler,
)
from src.schemas import ResolvedConfiguration
from src.telemetry import prometheus_metrics
from src.telemetry.sentry import initialize_sentry
from src.utils.work_unit import parse_work_unit_key
from src.webhooks.events import (
    QueueEmptyEvent,
    publish_webhook_event,
)

# Module-level logger for the queue manager.
logger = getLogger(__name__)

# Load variables from a .env file into the process environment, overriding any
# values that are already set (override=True).
# NOTE(review): this runs after `src.config.settings` has been imported above;
# if `settings` is populated at import time it will not see variables loaded
# here — confirm whether src.config loads the .env file itself.
load_dotenv(override=True)


class WorkerOwnership(NamedTuple):
    """Pair of the work unit a worker is processing and the session it holds for it."""

    # Key of the work unit currently being processed by the worker
    work_unit_key: str
    # ID of the ActiveQueueSession row claimed for that work unit
    aqs_id: str


def _detach_queue_batch_objects(
    db: AsyncSession,
    messages_context: list[models.Message],
    items_to_process: list[QueueItem],
) -> None:
    """
    Expunge every object of a loaded batch from ``db``.

    This keeps the batch's Message and QueueItem instances usable after the
    ``tracked_db`` session exits. Each distinct object is expunged exactly once,
    in first-seen order (messages first, then queue items), even if it appears
    more than once in the inputs.
    """
    # Keyed by identity: a dict keeps first-insertion order and drops repeats.
    unique_objects = {
        id(candidate): candidate
        for candidate in (*messages_context, *items_to_process)
    }
    for candidate in unique_objects.values():
        db.expunge(candidate)


def _resolve_batch_configuration(
    items_to_process: list[QueueItem],
) -> tuple[list[QueueItem], ResolvedConfiguration | None]:
    """
    Trim a batch to its leading run of items that share one configuration.

    The batch configuration is taken from the first item's payload
    ``"configuration"`` entry (None when absent). Items are kept in order until
    the first one whose resolved configuration differs.

    Returns:
        (kept_items, batch_configuration); ``([], None)`` for an empty input.
    """

    def _config_of(item: QueueItem) -> ResolvedConfiguration | None:
        raw = item.payload.get("configuration")
        if raw is None:
            return None
        return ResolvedConfiguration.model_validate(raw)

    if not items_to_process:
        return [], None

    batch_config = _config_of(items_to_process[0])

    prefix_len = 0
    for item in items_to_process:
        if _config_of(item) != batch_config:
            break
        prefix_len += 1

    return items_to_process[:prefix_len], batch_config


class QueueManager:
    def __init__(self):
        # Lifecycle signals and bookkeeping for in-flight work
        self.shutdown_event: asyncio.Event = asyncio.Event()
        self.queue_empty_flag: asyncio.Event = asyncio.Event()
        self.active_tasks: set[asyncio.Task[None]] = set()
        self.worker_ownership: dict[str, WorkerOwnership] = {}

        # Worker count comes from settings; the semaphore caps concurrent
        # work-unit processing at that count.
        worker_count: int = settings.DERIVER.WORKERS
        self.workers: int = worker_count
        self.semaphore: asyncio.Semaphore = asyncio.Semaphore(worker_count)

        # Reuse the process-wide dream scheduler, registering one if missing
        dream_scheduler = get_dream_scheduler()
        if dream_scheduler is None:
            dream_scheduler = DreamScheduler()
            set_dream_scheduler(dream_scheduler)
        self.dream_scheduler: DreamScheduler = dream_scheduler

        # Same pattern for the reconciler scheduler
        reconciler = get_reconciler_scheduler()
        if reconciler is None:
            reconciler = ReconcilerScheduler()
            set_reconciler_scheduler(reconciler)
        self.reconciler_scheduler: ReconcilerScheduler = reconciler

        # Sentry is opt-in via settings
        if settings.SENTRY.ENABLED:
            initialize_sentry(integrations=[AsyncioIntegration()])

    def add_task(self, task: asyncio.Task[None]) -> None:
        """Register ``task`` in ``active_tasks``; it removes itself once it finishes."""
        tracked = self.active_tasks
        tracked.add(task)
        task.add_done_callback(tracked.discard)

    def track_worker_work_unit(
        self, worker_id: str, work_unit_key: str, aqs_id: str
    ) -> None:
        """Record that ``worker_id`` now owns ``work_unit_key`` via session ``aqs_id``."""
        record = WorkerOwnership(work_unit_key=work_unit_key, aqs_id=aqs_id)
        self.worker_ownership[worker_id] = record

    def untrack_worker_work_unit(self, worker_id: str, work_unit_key: str) -> None:
        """
        Drop ``worker_id``'s ownership record, but only while it still refers to
        ``work_unit_key``; any other record for that worker is left untouched.
        """
        current = self.worker_ownership.get(worker_id)
        if not current or current.work_unit_key != work_unit_key:
            return
        del self.worker_ownership[worker_id]

    def create_worker_id(self) -> str:
        """Return a new nanoid identifying one work-unit processing task."""
        new_id = generate_nanoid()
        return new_id

    def get_total_owned_work_units(self) -> int:
        """Return how many work units are currently tracked (one entry per worker)."""
        ownership_map = self.worker_ownership
        return len(ownership_map)

    async def initialize(self) -> None:
        """
        Register signal handlers, start the reconciler, and run the polling loop.

        SIGTERM/SIGINT trigger ``shutdown()``. When the polling loop stops after a
        graceful shutdown, this waits for in-flight work-unit tasks to finish
        before ``cleanup()`` deletes the remaining ActiveQueueSession rows.
        Otherwise a work unit still being processed could be released and claimed
        by another worker. ``cleanup()`` always runs on exit.
        """
        logger.debug(f"Initializing QueueManager with {self.workers} workers")

        loop = asyncio.get_running_loop()

        # The event loop keeps only weak references to tasks, so keep the
        # shutdown tasks spawned from signal handlers alive until they finish.
        shutdown_tasks: set[asyncio.Task[None]] = set()

        def _on_signal(sig: signal.Signals) -> None:
            shutdown_task = asyncio.create_task(self.shutdown(sig))
            shutdown_tasks.add(shutdown_task)
            shutdown_task.add_done_callback(shutdown_tasks.discard)

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, _on_signal, sig)
        logger.debug("Signal handlers registered")

        # Start the reconciler scheduler (a failure here is logged, not fatal)
        try:
            await self.reconciler_scheduler.start()
        except Exception:
            logger.exception("Failed to start reconciler scheduler")

        # Run the polling loop directly in this task
        logger.debug("Starting polling loop directly")
        try:
            await self.polling_loop()
        finally:
            try:
                if self.shutdown_event.is_set() and self.active_tasks:
                    # asyncio.wait (unlike gather) does not cancel the workers
                    # if this task itself gets cancelled while waiting.
                    await asyncio.wait(set(self.active_tasks))
            finally:
                await self.cleanup()

    async def shutdown(self, sig: signal.Signals) -> None:
        """
        Handle graceful shutdown after receiving ``sig``.

        Sets ``shutdown_event`` (which stops the polling and worker loops), stops
        the dream and reconciler schedulers, then waits for the active
        work-unit tasks. Each scheduler is stopped independently, so one failing
        does not prevent the other from stopping or the tasks from being awaited.
        """
        logger.info(f"Received exit signal {sig.name}...")
        self.shutdown_event.set()

        schedulers = (
            ("dream scheduler", self.dream_scheduler),  # cancels pending dreams
            ("reconciler scheduler", self.reconciler_scheduler),
        )
        for scheduler_name, scheduler in schedulers:
            try:
                await scheduler.shutdown()
            except Exception:
                logger.exception(f"Failed to shut down {scheduler_name}")

        if self.active_tasks:
            logger.info(
                f"Waiting for {len(self.active_tasks)} active tasks to complete..."
            )
            await asyncio.gather(*self.active_tasks, return_exceptions=True)

    async def cleanup(self) -> None:
        """
        Release every work unit still tracked in ``worker_ownership``.

        Deletes the corresponding ActiveQueueSession rows so the work units can be
        claimed again. Then it clears the local tracking, even when the database
        call fails. Failures are logged with their traceback and forwarded to
        Sentry when enabled.
        """
        total_work_units = self.get_total_owned_work_units()
        if total_work_units == 0:
            return

        logger.debug(f"Cleaning up {total_work_units} owned work units...")
        try:
            # Use the tracked_db dependency for transaction safety
            async with tracked_db("queue_cleanup") as db:
                aqs_ids = [
                    ownership.aqs_id for ownership in self.worker_ownership.values()
                ]
                if aqs_ids:
                    await db.execute(
                        delete(models.ActiveQueueSession).where(
                            models.ActiveQueueSession.id.in_(aqs_ids)
                        )
                    )
                await db.commit()
        except Exception as e:
            # logger.exception keeps the traceback that logger.error(str(e)) dropped
            logger.exception(f"Error during cleanup: {e}")
            if settings.SENTRY.ENABLED:
                sentry_sdk.capture_exception(e)
        finally:
            self.worker_ownership.clear()

    ##########################
    # Polling and Scheduling #
    ##########################

    async def cleanup_stale_work_units(self) -> None:
        """
        Delete ActiveQueueSession rows not updated within the stale timeout.

        Rows whose ``last_updated`` is older than STALE_SESSION_TIMEOUT_MINUTES
        are locked with ``FOR UPDATE SKIP LOCKED`` (oldest first). Rows already
        locked by another transaction are skipped, and only the rows this
        transaction locked are deleted.
        """
        async with tracked_db("cleanup_stale_work_units") as db:
            timeout = timedelta(
                minutes=settings.DERIVER.STALE_SESSION_TIMEOUT_MINUTES
            )
            cutoff = datetime.now(timezone.utc) - timeout

            lock_stale_query = (
                select(models.ActiveQueueSession.id)
                .where(models.ActiveQueueSession.last_updated < cutoff)
                .order_by(models.ActiveQueueSession.last_updated)
                .with_for_update(skip_locked=True)
            )
            locked_result = await db.execute(lock_stale_query)
            stale_ids = locked_result.scalars().all()

            if stale_ids:
                delete_stale = delete(models.ActiveQueueSession).where(
                    models.ActiveQueueSession.id.in_(stale_ids)
                )
                await db.execute(delete_stale)
            await db.commit()

    async def get_and_claim_work_units(self) -> dict[str, str]:
        """
        Claim unowned work units that still have unprocessed queue items.

        At most ``self.workers`` minus the number of already-owned work units are
        claimed. A ``representation:`` work unit must have at least
        REPRESENTATION_BATCH_MAX_TOKENS tokens across its unprocessed messages to
        qualify. That requirement is skipped when FLUSH_ENABLED is set or the
        threshold is not positive. Other work units always qualify.

        Returns:
            Mapping of each claimed work_unit_key to its ActiveQueueSession id;
            empty when there is no free capacity or nothing to claim.
        """
        free_slots: int = max(0, self.workers - self.get_total_owned_work_units())
        if not free_slots:
            return {}

        token_threshold = settings.DERIVER.REPRESENTATION_BATCH_MAX_TOKENS

        async with tracked_db("get_available_work_units") as db:
            rep_prefix = "representation:"

            # Sum of unprocessed message tokens per representation work unit
            pending_tokens = (
                select(
                    models.QueueItem.work_unit_key,
                    func.sum(models.Message.token_count).label("total_tokens"),
                )
                .join(
                    models.Message,
                    models.QueueItem.message_id == models.Message.id,
                )
                .where(~models.QueueItem.processed)
                .where(models.QueueItem.work_unit_key.startswith(rep_prefix))
                .group_by(models.QueueItem.work_unit_key)
                .subquery()
            )

            # Every work unit with at least one unprocessed queue item
            pending_units = (
                select(models.QueueItem.work_unit_key)
                .where(~models.QueueItem.processed)
                .group_by(models.QueueItem.work_unit_key)
                .subquery()
            )
            unit_key = pending_units.c.work_unit_key

            # Exclude work units that already have an ActiveQueueSession
            already_claimed = (
                select(models.ActiveQueueSession.id)
                .where(models.ActiveQueueSession.work_unit_key == unit_key)
                .exists()
            )
            filters = [~already_claimed]

            # Batch-size gate for representation work units
            if not settings.DERIVER.FLUSH_ENABLED and token_threshold > 0:
                filters.append(
                    or_(
                        ~unit_key.startswith(rep_prefix),
                        func.coalesce(pending_tokens.c.total_tokens, 0)
                        >= token_threshold,
                    )
                )

            query = (
                select(unit_key)
                .outerjoin(
                    pending_tokens,
                    unit_key == pending_tokens.c.work_unit_key,
                )
                .where(*filters)
                .limit(free_slots)
            )

            available_units = (await db.execute(query)).scalars().all()
            claimed: dict[str, str] = {}
            if available_units:
                claimed = await self.claim_work_units(db, available_units)
            await db.commit()
            return claimed

    async def claim_work_units(
        self, db: AsyncSession, work_unit_keys: Sequence[str]
    ) -> dict[str, str]:
        """
        Claim work units by inserting one ActiveQueueSession row per key.

        Keys that conflict with an existing row are skipped (ON CONFLICT DO
        NOTHING), so only successfully claimed keys appear in the result.

        Args:
            db: Open session; the caller is responsible for committing.
            work_unit_keys: Work unit keys to try to claim.

        Returns:
            Mapping of claimed work_unit_key to its new ActiveQueueSession id.
            Empty when ``work_unit_keys`` is empty.
        """
        # insert().values([]) does not insert zero rows: it degrades to a
        # single-row insert with no explicit columns, so never issue it.
        if not work_unit_keys:
            return {}

        values = [{"work_unit_key": key} for key in work_unit_keys]

        stmt = (
            insert(models.ActiveQueueSession)
            .values(values)
            .on_conflict_do_nothing()
            .returning(
                models.ActiveQueueSession.work_unit_key, models.ActiveQueueSession.id
            )
        )

        result = await db.execute(stmt)
        claimed_mapping = {key: aqs_id for key, aqs_id in result.all()}
        logger.debug(
            f"Claimed {len(claimed_mapping)} work units: {list(claimed_mapping.keys())}"
        )
        return claimed_mapping

    async def polling_loop(self) -> None:
        """
        Main polling loop: find claimable work units and start a task for each.

        On each iteration it clears stale sessions, claims work units up to the
        free capacity, and spawns ``process_work_unit`` for each claim. When a
        poll finds nothing, ``queue_empty_flag`` is set and the next iteration
        waits one polling interval before polling again. All waits return early
        once shutdown is requested.
        """
        logger.debug("Starting polling loop")
        try:
            while not self.shutdown_event.is_set():
                if self.queue_empty_flag.is_set():
                    # The previous poll found nothing: back off exactly once.
                    await self._wait_poll_interval()
                    self.queue_empty_flag.clear()
                    continue

                # Check if we have capacity before querying
                if self.semaphore.locked():
                    await self._wait_poll_interval()
                    continue

                try:
                    await self.cleanup_stale_work_units()
                    claimed_work_units = await self.get_and_claim_work_units()
                    if not claimed_work_units:
                        # The wait happens at the top of the next iteration;
                        # waiting here as well would double the idle interval.
                        self.queue_empty_flag.set()
                        continue

                    for work_unit_key, aqs_id in claimed_work_units.items():
                        worker_id = self.create_worker_id()
                        # Track every claimed unit, even during shutdown, so that
                        # cleanup() deletes its ActiveQueueSession instead of
                        # leaving it locked until the stale timeout.
                        self.track_worker_work_unit(worker_id, work_unit_key, aqs_id)
                        if self.shutdown_event.is_set():
                            continue

                        task: Task[None] = asyncio.create_task(
                            self.process_work_unit(work_unit_key, worker_id)
                        )
                        self.add_task(task)
                except Exception as e:
                    logger.exception("Error in polling loop")
                    if settings.SENTRY.ENABLED:
                        sentry_sdk.capture_exception(e)
                    # Note: rollback is handled by tracked_db dependency
                    await self._wait_poll_interval()
        finally:
            logger.info("Polling loop stopped")

    async def _wait_poll_interval(self) -> None:
        """Wait one polling interval, returning early if shutdown is requested."""
        try:
            await asyncio.wait_for(
                self.shutdown_event.wait(),
                timeout=settings.DERIVER.POLLING_SLEEP_INTERVAL_SECONDS,
            )
        except asyncio.TimeoutError:
            pass

    ######################
    # Queue Worker Logic #
    ######################

    async def _handle_processing_error(
        self,
        error: Exception,
        items: list[QueueItem],
        work_unit_key: str,
        context: str,
    ) -> None:
        """
        Record a processing failure: log it, forward it to Sentry, and mark the
        first queue item as errored.

        Only the first item is marked, so a failing batch is not thrown away
        wholesale. The other items stay unprocessed and can be fetched again,
        which lets the work unit keep making progress item by item.

        Args:
            error: The exception that occurred
            items: The queue items that were being processed
            work_unit_key: The work unit key for the queue items
            context: Context string describing what was being processed (e.g., "processing representation batch")
        """
        # Log first, taking the traceback from ``error`` itself so the record
        # doesn't depend on the caller still being inside an ``except`` block.
        logger.error(
            f"Error {context} for work unit {work_unit_key}: {error}",
            exc_info=error,
        )
        if settings.SENTRY.ENABLED:
            sentry_sdk.capture_exception(error)

        if not items:
            return

        error_msg = f"{error.__class__.__name__}: {str(error)}"
        try:
            await self.mark_queue_item_as_errored(items[0], work_unit_key, error_msg)
        except Exception as mark_error:
            logger.error(
                f"Failed to mark queue item as errored for work unit {work_unit_key}: {mark_error}",
                exc_info=True,
            )

    async def process_work_unit(self, work_unit_key: str, worker_id: str) -> None:
        """
        Process all queue items for a specific work unit by routing to the correct handler.

        Representation work units are drained in batches (``get_queue_item_batch``
        + ``process_representation_batch``). Other task types are drained one item
        at a time (``get_next_queue_item`` + ``process_item``). A failure while
        processing fetched items goes to ``_handle_processing_error`` and the loop
        fetches again.

        The loop stops when any of these happens:
        - no items remain;
        - the worker no longer owns the work unit;
        - shutdown is requested;
        - an unexpected error escapes the fetch/handling step. The work unit is
          then released instead of retried in a tight loop, and the polling loop
          can claim it again later.

        On exit, the ActiveQueueSession is removed if it is still owned, and local
        ownership tracking is always dropped. A ``queue.empty`` webhook is
        published for representation/summary work units when this worker removed
        the session and processed at least one item.

        Args:
            work_unit_key: Key of the claimed work unit.
            worker_id: Worker ID under which the polling loop tracked ownership.
        """
        logger.debug(f"Starting to process work unit {work_unit_key}")
        work_unit = None
        async with self.semaphore:
            queue_item_count = 0
            try:
                # Parsed inside the try so a malformed key still releases the
                # ownership slot and ActiveQueueSession in the finally below.
                work_unit = parse_work_unit_key(work_unit_key)
                while not self.shutdown_event.is_set():
                    # Get worker ownership info for verification
                    ownership = self.worker_ownership.get(worker_id)
                    if not ownership or ownership.work_unit_key != work_unit_key:
                        logger.warning(
                            f"Worker {worker_id} lost ownership of work unit {work_unit_key}, stopping processing {work_unit_key}"
                        )
                        break
                    try:
                        if work_unit.task_type == "representation":
                            (
                                messages_context,
                                items_to_process,
                                message_level_configuration,
                            ) = await self.get_queue_item_batch(
                                work_unit.task_type, work_unit_key, ownership.aqs_id
                            )
                            logger.debug(
                                f"Worker {worker_id} retrieved {len(messages_context)} messages and {len(items_to_process)} queue items for work unit {work_unit_key} (AQS ID: {ownership.aqs_id})"
                            )
                            if not items_to_process:
                                logger.debug(
                                    f"No more queue items to process for work unit {work_unit_key} for worker {worker_id}"
                                )
                                break

                            try:
                                # Observers: new "observers" list, or the legacy
                                # single "observer" string, or none.
                                payload = items_to_process[0].payload
                                observers = payload.get("observers")
                                if observers is None:
                                    legacy_observer = payload.get("observer")
                                    observers = (
                                        [legacy_observer] if legacy_observer else []
                                    )

                                queue_item_message_ids = [
                                    item.message_id
                                    for item in items_to_process
                                    if item.message_id is not None
                                ]
                                await process_representation_batch(
                                    messages_context,
                                    message_level_configuration,
                                    observers=observers,
                                    observed=work_unit.observed,
                                    queue_item_message_ids=queue_item_message_ids,
                                )
                                await self.mark_queue_items_as_processed(
                                    items_to_process, work_unit_key
                                )
                                queue_item_count += len(items_to_process)
                            except Exception as e:
                                await self._handle_processing_error(
                                    e,
                                    items_to_process,
                                    work_unit_key,
                                    f"processing {work_unit.task_type} batch",
                                )

                        else:
                            queue_item = await self.get_next_queue_item(
                                work_unit.task_type, work_unit_key, ownership.aqs_id
                            )
                            if not queue_item:
                                logger.debug(
                                    f"No more queue items to process for work unit {work_unit_key} for worker {worker_id}"
                                )
                                break

                            try:
                                await process_item(queue_item)
                                await self.mark_queue_items_as_processed(
                                    [queue_item], work_unit_key
                                )
                                queue_item_count += 1
                            except Exception as e:
                                await self._handle_processing_error(
                                    e,
                                    [queue_item],
                                    work_unit_key,
                                    "processing queue item",
                                )

                    except Exception as e:
                        logger.error(
                            f"Error in processing loop for work unit {work_unit_key}: {e}",
                            exc_info=True,
                        )
                        if settings.SENTRY.ENABLED:
                            sentry_sdk.capture_exception(e)
                        # Errors reaching here come from fetching work (e.g. an
                        # invalid configuration payload) and are likely to recur
                        # on an immediate retry; release the work unit instead of
                        # spinning. The polling loop can claim it again later.
                        break

                    # Check for shutdown after processing each batch
                    if self.shutdown_event.is_set():
                        logger.debug(
                            "Shutdown requested, stopping processing for work unit %s",
                            work_unit_key,
                        )
                        break

            finally:
                # Remove work unit from active_queue_sessions when done
                removed = False
                try:
                    ownership = self.worker_ownership.get(worker_id)
                    if ownership and ownership.work_unit_key == work_unit_key:
                        removed = await self._cleanup_work_unit(
                            ownership.aqs_id, work_unit_key
                        )
                finally:
                    # Always free the local slot, even if session removal failed
                    self.untrack_worker_work_unit(worker_id, work_unit_key)

                if not removed:
                    logger.debug(
                        f"Work unit {work_unit_key} already cleaned up by another worker, skipping webhook"
                    )
                elif queue_item_count == 0 or work_unit is None:
                    logger.debug(
                        f"No queue items processed for work unit {work_unit_key}, skipping webhook"
                    )
                else:
                    # Only publish webhook if we actually removed an active session
                    try:
                        if (
                            work_unit.task_type in ["representation", "summary"]
                            and work_unit.workspace_name is not None
                        ):
                            logger.debug(
                                f"Publishing queue.empty event for {work_unit_key} in workspace {work_unit.workspace_name}"
                            )
                            await publish_webhook_event(
                                QueueEmptyEvent(
                                    workspace_id=work_unit.workspace_name,
                                    queue_type=work_unit.task_type,
                                    session_id=work_unit.session_name,
                                    observer=work_unit.observer,
                                    observed=work_unit.observed,
                                )
                            )
                    except Exception:
                        logger.exception("Error triggering queue_empty webhook")

    @sentry_sdk.trace
    async def get_next_queue_item(
        self, task_type: str, work_unit_key: str, aqs_id: str
    ) -> QueueItem | None:
        """
        Return the lowest-id unprocessed queue item of ``work_unit_key``, or None.

        The query joins ActiveQueueSession and requires its id to equal
        ``aqs_id``, so nothing is returned once this worker's session is gone.
        Representation work units are batched and must use
        ``get_queue_item_batch`` instead; they raise ValueError here.
        """
        if task_type == "representation":
            raise ValueError(
                "representation tasks are not supported for get_next_queue_item"
            )

        async with tracked_db("get_next_queue_item") as db:
            next_item_query = (
                select(models.QueueItem)
                .join(
                    models.ActiveQueueSession,
                    models.QueueItem.work_unit_key
                    == models.ActiveQueueSession.work_unit_key,
                )
                .where(models.QueueItem.work_unit_key == work_unit_key)
                .where(~models.QueueItem.processed)
                .where(
                    # Ownership check: the session must still be ours
                    models.ActiveQueueSession.work_unit_key == work_unit_key,
                    models.ActiveQueueSession.id == aqs_id,
                )
                .order_by(models.QueueItem.id)
                .limit(1)
            )
            next_item = (await db.execute(next_item_query)).scalar_one_or_none()

            # Commit instead of letting tracked_db roll back, which would expire
            # the instance; expire_on_commit=False keeps its attributes readable.
            await db.commit()
            return next_item

    @sentry_sdk.trace
    async def get_queue_item_batch(
        self,
        task_type: str,
        work_unit_key: str,
        aqs_id: str,
    ) -> tuple[list[models.Message], list[QueueItem], ResolvedConfiguration | None]:
        """
        Fetch the next token-bounded batch for a representation work unit.

        Only representation tasks are supported; any other ``task_type`` raises
        ``ValueError``.

        Returns a tuple of (messages_context, items_to_process, configuration).
        - messages_context: unique Message rows (conversation turns) forming the
          context window, trimmed to end at the last message referenced by
          ``items_to_process``
        - items_to_process: unprocessed QueueItems for ``work_unit_key`` within that
          window, cut down to the leading run whose payload configuration matches
          the first item's
        - configuration: the first item's resolved configuration (None when its
          payload has no "configuration")

        Returns ``([], [], None)`` when ``aqs_id`` no longer owns ``work_unit_key``
        or the query yields no rows. Returned ORM objects are expunged from the
        session so they remain usable after it closes.
        """
        if task_type != "representation":
            raise ValueError(
                f"{task_type} tasks are not supported for get_queue_item_batch"
            )

        batch_max_tokens = settings.DERIVER.REPRESENTATION_BATCH_MAX_TOKENS
        parsed_key = parse_work_unit_key(work_unit_key)
        messages_context: list[models.Message] = []
        items_to_process: list[QueueItem] = []

        async with tracked_db("get_queue_item_batch") as db:
            # For batch tasks, get messages based on token limit.
            # Step 1: Verify worker still owns the work_unit_key.
            ownership_check = await db.execute(
                select(models.ActiveQueueSession.id)
                .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
                .where(models.ActiveQueueSession.id == aqs_id)
            )
            if not ownership_check.scalar_one_or_none():
                return [], [], None

            # Step 2: Build a single SQL query that:
            # 1. Finds the earliest message with an unprocessed queue item for this
            #    work_unit_key
            # 2. Optionally includes the immediately preceding session message if it
            #    was sent by a peer other than the observed peer (for context)
            # 3. Gets ALL session messages from that point forward (for
            #    conversational context), with a running cumulative token count
            # 4. Keeps messages whose cumulative token count is <= batch_max_tokens,
            #    and always keeps the earliest unprocessed message even when that
            #    count is exceeded
            # 5. Left-joins each kept message to its unprocessed QueueItems for this
            #    work_unit_key; messages without one serve only as context

            # Find the minimum message_id with an unprocessed queue item for this
            # work unit (scoped to its session and workspace)
            min_unprocessed_message_id_subq = (
                select(func.min(models.Message.id))
                .select_from(models.QueueItem)
                .join(
                    models.Message,
                    models.QueueItem.message_id == models.Message.id,
                )
                .where(~models.QueueItem.processed)
                .where(models.Message.session_name == parsed_key.session_name)
                .where(models.Message.workspace_name == parsed_key.workspace_name)
                .where(models.QueueItem.work_unit_key == work_unit_key)
                .scalar_subquery()
            )

            # Find the immediately preceding message ID (the one right before min_unprocessed)
            immediately_preceding_id_subq = (
                select(func.max(models.Message.id))
                .where(models.Message.session_name == parsed_key.session_name)
                .where(models.Message.workspace_name == parsed_key.workspace_name)
                .where(models.Message.id < min_unprocessed_message_id_subq)
                .scalar_subquery()
            )

            # Only include the preceding message if it's from a different peer than observed
            # This provides conversational context (e.g., the question that prompted the response)
            preceding_message_id_subq = (
                select(models.Message.id)
                .where(models.Message.id == immediately_preceding_id_subq)
                .where(models.Message.peer_name != parsed_key.observed)
                .scalar_subquery()
            )

            # Determine the effective start: preceding message if it qualifies, else min_unprocessed
            # We use COALESCE to fall back to min_unprocessed if no preceding message qualifies
            effective_start_id = func.coalesce(
                preceding_message_id_subq, min_unprocessed_message_id_subq
            )

            # Build CTE with ALL messages starting from effective_start_id
            # This includes the preceding context message (if any) and interleaving messages.
            # cumulative_token_count is a running sum ordered by message id, so the
            # context message's tokens count toward the batch limit too.
            cte = (
                select(
                    models.Message.id.label("message_id"),
                    models.Message.token_count.label("token_count"),
                    models.Message.peer_name.label("peer_name"),
                    func.sum(models.Message.token_count)
                    .over(order_by=models.Message.id)
                    .label("cumulative_token_count"),
                )
                .where(models.Message.session_name == parsed_key.session_name)
                .where(models.Message.workspace_name == parsed_key.workspace_name)
                .where(models.Message.id >= effective_start_id)
                .order_by(models.Message.id)
                .cte()
            )

            allowed_condition = (
                (cte.c.cumulative_token_count <= batch_max_tokens)
                | (
                    cte.c.message_id == min_unprocessed_message_id_subq
                )  # always include the first unprocessed message
            )

            # A message may join several queue items (or none); rows are ordered so
            # messages and their items come out in id order.
            query = (
                select(models.Message, models.QueueItem)
                .select_from(cte)
                .join(models.Message, models.Message.id == cte.c.message_id)
                .outerjoin(
                    models.QueueItem,
                    and_(
                        models.QueueItem.work_unit_key == work_unit_key,
                        ~models.QueueItem.processed,
                        models.QueueItem.message_id == models.Message.id,
                    ),
                )
                .where(allowed_condition)
                .order_by(models.Message.id, models.QueueItem.id)
            )

            result = await db.execute(query)
            rows = result.all()
            if not rows:
                return [], [], None

            # Collapse the joined rows: each message once, every matched item kept
            seen_messages: set[int] = set()
            for m, qi in rows:
                if m.id not in seen_messages:
                    messages_context.append(m)
                    seen_messages.add(m.id)
                if qi is not None:
                    items_to_process.append(qi)

            # Detach before the session closes so the objects stay readable
            _detach_queue_batch_objects(db, messages_context, items_to_process)

        # Keep only the leading items sharing the first item's configuration; the
        # others stay unprocessed and can be picked up by a later call.
        items_to_process, resolved_config = _resolve_batch_configuration(
            items_to_process
        )

        if items_to_process:
            # Drop trailing context messages that come after the last message this
            # batch actually processes.
            max_queue_item_message_id = max(
                qi.message_id for qi in items_to_process if qi.message_id is not None
            )
            messages_context = [
                m for m in messages_context if m.id <= max_queue_item_message_id
            ]

        return messages_context, items_to_process, resolved_config

    async def mark_queue_items_as_processed(
        self, items: list[QueueItem], work_unit_key: str
    ) -> None:
        """Flag a batch of queue items as processed and touch their session.

        In one transaction this method:

        - sets ``processed=True`` on every QueueItem whose id is in ``items``
          and whose ``work_unit_key`` matches;
        - refreshes ``last_updated`` on the ActiveQueueSession for that key;
        - commits.

        Afterwards, the processed count is recorded in Prometheus, but only
        when all of these hold:

        - the work unit's task type is ``representation`` or ``summary``;
        - it carries a workspace name;
        - metrics are enabled.

        An empty ``items`` list is a no-op.
        """
        if not items:
            return
        async with tracked_db("process_queue_item_batch") as db:
            parsed_key = parse_work_unit_key(work_unit_key)
            ids_to_mark = [queue_item.id for queue_item in items]

            mark_processed_stmt = (
                update(models.QueueItem)
                .where(models.QueueItem.id.in_(ids_to_mark))
                .where(models.QueueItem.work_unit_key == work_unit_key)
                .values(processed=True)
            )
            touch_session_stmt = (
                update(models.ActiveQueueSession)
                .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
                .values(last_updated=func.now())
            )
            await db.execute(mark_processed_stmt)
            await db.execute(touch_session_stmt)
            await db.commit()

            # Guard clauses in the same order as the metric preconditions.
            if parsed_key.task_type not in ("representation", "summary"):
                return
            if parsed_key.workspace_name is None:
                return
            if not settings.METRICS.ENABLED:
                return
            prometheus_metrics.record_deriver_queue_item(
                count=len(items),
                workspace_name=parsed_key.workspace_name,
                task_type=parsed_key.task_type,
            )

    async def mark_queue_item_as_errored(
        self, item: QueueItem, work_unit_key: str, error: str
    ) -> None:
        """Record a failed queue item as processed, storing its error message.

        In one transaction this method:

        - sets ``processed=True`` and ``error`` on the QueueItem with
          ``item.id`` (only if its ``work_unit_key`` matches);
        - refreshes ``last_updated`` on the ActiveQueueSession for that key;
        - commits.

        A falsy ``item`` is ignored.
        """
        if not item:
            return
        async with tracked_db("mark_queue_item_as_errored") as db:
            # Cap the stored message at 65,535 characters.
            # NOTE(review): PostgreSQL TEXT itself has no such limit; confirm
            # whether the column definition actually requires this bound.
            capped_error = error[:65535]
            errored_stmt = (
                update(models.QueueItem)
                .where(models.QueueItem.id == item.id)
                .where(models.QueueItem.work_unit_key == work_unit_key)
                .values(processed=True, error=capped_error)
            )
            touch_session_stmt = (
                update(models.ActiveQueueSession)
                .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
                .values(last_updated=func.now())
            )
            await db.execute(errored_stmt)
            await db.execute(touch_session_stmt)
            await db.commit()

    async def _cleanup_work_unit(
        self,
        aqs_id: str,
        work_unit_key: str,
    ) -> bool:
        """Delete the ActiveQueueSession identified by both ``aqs_id`` and ``work_unit_key``.

        A row is removed only when its id AND its work_unit_key both match.
        A session row for the same key but with a different id is left alone.

        Returns:
            True if the delete affected at least one row, otherwise False.
        """
        removal_stmt = (
            delete(models.ActiveQueueSession)
            .where(models.ActiveQueueSession.id == aqs_id)
            .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
        )
        async with tracked_db("cleanup_work_unit") as db:
            outcome = cast(CursorResult[Any], await db.execute(removal_stmt))
            await db.commit()
            removed_any = outcome.rowcount > 0
        return removed_any


async def main():
    """Entry point for the queue manager process.

    Steps:

    1. Initialize the cache on a best-effort basis. A failure is logged and
       the manager continues without a cache.
    2. Construct and run the QueueManager.
    3. Close the cache when the manager finishes or fails.

    Exceptions (``Exception`` subclasses) raised while constructing or running
    the manager are logged with their traceback and reported to Sentry rather
    than propagated.
    """
    logger.debug("Starting queue manager")

    try:
        await init_cache()
    except Exception as e:
        logger.warning(
            "Error initializing cache in queue manager; proceeding without cache: %s", e
        )

    try:
        # Constructed inside the try so close_cache() still runs if the
        # constructor raises; previously that path leaked the open cache.
        manager = QueueManager()
        await manager.initialize()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("Error in main: %s", e)
        sentry_sdk.capture_exception(e)
    finally:
        await close_cache()
        logger.debug("Main function exiting")
