import datetime
from collections.abc import Callable
from logging import getLogger
from typing import Any, TypeVar

from sqlalchemy import ColumnElement, Select, and_, case, cast, literal, not_, or_
from sqlalchemy.types import Numeric

from ..exceptions import FilterError
from .formatting import ILIKE_ESCAPE_CHAR, escape_ilike_pattern, parse_datetime_iso

# Module-level logger named after this module.
logger = getLogger(__name__)

# Type variable for SQLAlchemy model classes
T = TypeVar("T")

# Module-level constants for comparison operators.
# A dict value containing at least one of these keys is interpreted as a set
# of comparisons instead of a literal value to match.
COMPARISON_OPERATORS = {
    "gte",
    "lte",
    "gt",
    "lt",
    "ne",
    "in",
    "contains",
    "icontains",
}

# Operators whose operand is coerced to a number before comparing: regular
# columns go through float(), JSONB fields go through _safe_numeric_cast.
NUMERIC_OPERATORS = {"gte", "lte", "gt", "lt", "ne"}

# Default mapping from the filter keys a caller may use to the model attribute
# that is actually queried. Applied to every model class that is not named
# "Message" or "Document"; keys missing from it are rejected with FilterError.
ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING = {
    "id": "name",
    "created_at": "created_at",
    "is_active": "is_active",
    "workspace_id": "workspace_name",
    "session_id": "session_name",
    "peer_id": "peer_name",
    "metadata": "h_metadata",
}

# Mapping applied when the model class is named "Message"; keys missing from
# it are rejected with FilterError.
ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_MESSAGES = {
    "workspace_id": "workspace_name",
    "session_id": "session_name",
    "peer_id": "peer_name",
    "token_count": "token_count",
    "created_at": "created_at",
    "metadata": "h_metadata",
}

# Mapping applied when the model class is named "Document". Unlike the two
# mappings above, a key missing from it falls back to the key itself, so it
# translates names rather than acting as a strict allow-list.
ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_DOCUMENTS = {
    "session_id": "session_name",
    "workspace_id": "workspace_name",
    "observer_id": "observer",
    "observed_id": "observed",
    "metadata": "internal_metadata",
}


def apply_filter(
    stmt: Select[tuple[T]], model_class: type[T], filters: dict[str, Any] | None = None
) -> Select[tuple[T]]:
    """
    Narrow a Select statement according to a user-facing filter dictionary.

    The filter language offers logical operators (AND, OR, NOT), comparison
    operators (gte, lte, gt, lt, ne, contains, icontains, in) and the
    wildcard character (*).

    Keys in the filter use the names the user sees, not the model attribute
    names: a user-facing `*_id` key is resolved to the matching `*_name`
    attribute and `metadata` is resolved to `h_metadata`.

    Examples:
        # Plain equality (backward compatible form)
        {"peer_id": "alice", "metadata": {"type": "user"}}

        # Logical operators
        {"AND": [{"peer_id": "alice"}, {"created_at": {"gte": "2024-01-01"}}]}
        {"OR": [{"peer_id": "alice"}, {"peer_id": "bob"}]}
        {"NOT": [{"peer_id": "alice"}]}

        # Comparison operators
        {"created_at": {"gte": "2024-01-01", "lte": "2024-12-31"}}
        {"peer_id": {"in": ["alice", "bob"]}}

        # Wildcard (places no restriction on the field)
        {"peer_id": "*"}

    Args:
        stmt: SQLAlchemy Select statement to narrow
        model_class: SQLAlchemy model class whose columns are filtered on
        filters: Filter dictionary, or None to leave the statement untouched

    Returns:
        The statement with a WHERE clause added when the filter produced a
        condition, otherwise the statement unchanged

    Raises:
        FilterError: When the filter contains invalid configuration or values
    """
    if filters is not None:
        where_clause = _build_filter_conditions(filters, model_class)
        # A filter made only of wildcards (or an empty dict) yields no clause.
        if where_clause is not None:
            return stmt.where(where_clause)

    return stmt


def _build_filter_conditions(
    filter_dict: dict[str, Any], model_class: type[Any], *, _depth: int = 0
) -> ColumnElement[bool] | None:
    """
    Recursively build filter conditions from a filter dictionary.

    The logical operator keys ``AND``, ``OR`` and ``NOT`` are handled first, in
    that order; each must hold a list of nested filter dictionaries. ``NOT``
    negates every nested filter individually and ANDs the negations together.
    Every other key is treated as a field-level condition. All conditions
    produced at this level are finally combined with AND.

    Args:
        filter_dict: Filter dictionary that may contain logical operators
        model_class: SQLAlchemy model class for column access
        _depth: Current recursion depth. Internal bookkeeping for the nesting
            limit; external callers should leave it at the default.

    Returns:
        SQLAlchemy condition object, or None when the filter places no
        restriction (e.g. it is empty or consists only of wildcards)

    Raises:
        FilterError: When nesting exceeds a depth of 5, when the filter (or a
            nested filter) is not a dictionary, when a logical operator does
            not hold a list, or when a field-level condition is invalid
    """
    if _depth > 5:
        raise FilterError("Filter nesting exceeds maximum depth of 5")

    # Filters come from callers as parsed JSON, so a nested entry may be a
    # string, number, list, ... Reject those explicitly; otherwise the `in`
    # checks and `.items()` below fail with TypeError/AttributeError.
    if not isinstance(filter_dict, dict):  # pyright: ignore[reportUnnecessaryIsInstance]
        raise FilterError(
            f"Filter must be a dictionary, got {type(filter_dict).__name__}"
        )

    conditions: list[ColumnElement[bool]] = []

    # Handle logical operators (fixed order: AND, OR, NOT)
    logical_keys = ("AND", "OR", "NOT")
    for logical_op in logical_keys:
        if logical_op not in filter_dict:
            continue

        operands = filter_dict[logical_op]
        if logical_op == "NOT" and operands is None:
            raise FilterError("NOT operator cannot be None")
        if not isinstance(operands, list):
            raise FilterError(
                f"{logical_op} operator must contain a list, got {type(operands).__name__}"
            )

        sub_conditions: list[ColumnElement[bool]] = []
        for sub_filter in operands:  # pyright: ignore
            sub_condition = _build_filter_conditions(
                sub_filter,  # pyright: ignore[reportUnknownArgumentType]
                model_class,
                _depth=_depth + 1,
            )
            # None means the nested filter matches everything; nothing to add.
            if sub_condition is not None:
                sub_conditions.append(sub_condition)

        if not sub_conditions:
            continue

        if logical_op == "AND":
            conditions.append(and_(*sub_conditions))
        elif logical_op == "OR":
            conditions.append(or_(*sub_conditions))
        else:
            # Apply NOT to each condition individually, then AND them together
            conditions.append(and_(*[not_(cond) for cond in sub_conditions]))

    # Handle field-level conditions (skip logical operator keys)
    for key, value in filter_dict.items():
        if key in logical_keys:
            continue

        condition = _build_field_condition(key, value, model_class)
        if condition is not None:
            conditions.append(condition)

    # Combine all conditions with AND
    return _combine_conditions_with_and(conditions)


def _build_field_condition(
    key: str, value: Any, model_class: type[Any]
) -> ColumnElement[bool] | None:
    """
    Translate one ``key: value`` filter entry into a condition.

    The user-facing key is first resolved to a model attribute through the
    mapping that belongs to the model (chosen by the model class name). The
    value then decides the kind of condition: a dict holding at least one
    comparison operator becomes comparisons on the column, any other dict on a
    JSONB column is matched field by field, a non-dict value on a JSONB column
    becomes a containment test, and everything else is plain equality.

    Args:
        key: User-facing field name
        value: Value to match, or a dict of comparison operators
        model_class: SQLAlchemy model class owning the column

    Returns:
        SQLAlchemy condition object, or None when the value is the wildcard
        or the comparisons place no restriction

    Raises:
        FilterError: When the key is not filterable for the model or the
            resolved attribute does not exist on it
    """
    model_name = model_class.__name__
    if model_name == "Document":
        # Unknown keys fall back to the key itself (kept for internal use).
        column_name = ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_DOCUMENTS.get(
            key, key
        )
    elif model_name == "Message":
        column_name = ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_MESSAGES.get(key)
    else:
        column_name = ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING.get(key)

    if column_name is None:
        raise FilterError(
            f"Column '{key}' is not allowed to be filtered on or does not exist on {model_class.__name__}"
        )

    # The mapping may name an attribute this particular model lacks.
    if not hasattr(model_class, column_name):
        raise FilterError(f"Column '{key}' does not exist on {model_class.__name__}")

    column = getattr(model_class, column_name)

    # The wildcard places no restriction on the field.
    if value == "*":
        return None

    targets_jsonb = column_name in ("h_metadata", "configuration", "internal_metadata")

    if not isinstance(value, dict):
        if targets_jsonb:
            return column.contains(value)
        return column == value

    # One known operator key is enough to treat the dict as comparisons.
    if any(op_key in COMPARISON_OPERATORS for op_key in value):  # pyright: ignore
        return _build_comparison_conditions(column, column_name, value)  # pyright: ignore

    # A plain dict value: JSONB columns get field-by-field matching, which
    # also picks up comparison operators nested under individual fields.
    if targets_jsonb:
        return _build_nested_metadata_conditions(column, value)  # pyright: ignore

    return column == value


def _safe_numeric_cast(
    column_accessor: ColumnElement[Any], op_value: Any
) -> tuple[ColumnElement[Any], Any]:
    """
    Safely cast JSONB column accessor to appropriate type for comparison.

    Three outcomes are possible, depending on ``op_value``:

    * bool: the accessor is returned uncast together with ``"true"`` or
      ``"false"`` so the comparison happens on text.
    * int/float, or anything ``int()``/``float()`` can parse: the accessor is
      wrapped in a CASE expression that casts the stored text to NUMERIC only
      when it looks like a number and yields NULL otherwise.
    * any other string (e.g. a date such as "2024-02-01"): the accessor is
      returned uncast and the comparison is a plain string comparison.

    Args:
        column_accessor: SQLAlchemy JSONB column accessor (.astext)
        op_value: The value to compare against

    Returns:
        Tuple of (cast_column_accessor, cast_op_value) for typed comparison
        or (column_accessor, str_op_value) for string comparison

    Raises:
        FilterError: When ``op_value`` is neither numeric nor a string, or
            when building the cast expression fails unexpectedly
    """
    try:
        if isinstance(op_value, bool):
            # For boolean values, compare with the string representation
            # PostgreSQL JSONB stores booleans as "true"/"false" strings when extracted with ->>
            return column_accessor, str(op_value).lower()

        parsed_value: int | float | None
        if isinstance(op_value, int | float):
            parsed_value = op_value
        else:
            # Try to parse as numeric (handles both strings and other types);
            # int first so whole numbers stay integers, then float.
            parsed_value = None
            for parse in (int, float):
                try:
                    parsed_value = parse(op_value)
                    break
                except (ValueError, TypeError):
                    continue

        if parsed_value is None:
            if isinstance(op_value, str):
                # If it's not numeric, treat as string comparison (e.g., dates, text)
                # This allows date strings like "2024-02-01" to be compared lexicographically
                return column_accessor, str(op_value)
            raise FilterError(
                f"Invalid value for numeric operator: {op_value}. Expected a number, got {type(op_value).__name__}"
            )

        # Only cast text that actually looks like a decimal number. Without
        # this guard a single row holding e.g. "high" would make
        # CAST(... AS NUMERIC) fail and abort the whole query. Empty strings
        # and NULLs do not match the pattern either, so they become NULL.
        numeric_text_pattern = r"^\s*[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?\s*$"
        safe_cast = case(
            (
                column_accessor.regexp_match(numeric_text_pattern),
                cast(column_accessor, Numeric()),
            ),
            else_=literal(None),
        )
        return safe_cast, parsed_value
    except FilterError:
        # Already a precise, user-facing error; do not wrap it a second time.
        raise
    except Exception as e:
        raise FilterError(
            f"Failed to process numeric cast for value '{op_value}': {str(e)}"
        ) from e


def _build_comparison_condition(
    column: Any, field_name: str, operator: str, op_value: Any
) -> ColumnElement[bool] | None:
    """
    Build a single comparison condition for a JSONB field.

    The field is read as text (``column[field_name].astext``), so every
    comparison is made against the textual form of the stored JSON value
    unless the numeric operators manage to cast it.

    Args:
        column: SQLAlchemy JSONB column object
        field_name: Name of the field in the JSONB column
        operator: Comparison operator
        op_value: Value to compare against

    Returns:
        SQLAlchemy condition object, or None when the value is (or contains)
        the wildcard and therefore places no restriction

    Raises:
        FilterError: When the operator is unsupported, when the value is not
            valid for the operator, or when the condition cannot be built
    """
    # Validate that the operator is supported
    if operator not in COMPARISON_OPERATORS:
        raise FilterError(f"Unsupported comparison operator: {operator}")

    # Handle wildcard - matches everything, so no condition needed
    if op_value == "*":
        return None

    field_accessor = column[field_name].astext

    if operator in NUMERIC_OPERATORS:
        try:
            safe_accessor, safe_value = _safe_numeric_cast(field_accessor, op_value)
            # Mapping of operators to their SQLAlchemy expressions
            operator_map: dict[str, Callable[[Any, Any], ColumnElement[bool]]] = {
                "gte": lambda a, v: a >= v,
                "lte": lambda a, v: a <= v,
                "gt": lambda a, v: a > v,
                "lt": lambda a, v: a < v,
                "ne": lambda a, v: a != v,
            }
            return operator_map[operator](safe_accessor, safe_value)
        except FilterError:
            # The cast helper already produced a precise message; keep it.
            raise
        except Exception as e:
            raise FilterError(
                f"Failed to build numeric comparison condition for operator '{operator}' with value '{op_value}': {str(e)}"
            ) from e

    if operator == "in":
        if not hasattr(op_value, "__iter__") or isinstance(op_value, str | bytes):
            raise FilterError(
                f"Invalid value for 'in' operator: {op_value}. Expected an iterable (list, tuple, set), got {type(op_value).__name__}"
            )
        # Materialize once so one-shot iterables survive the wildcard probe.
        candidates = list(op_value)
        # Handle wildcard in iterable - if present, matches everything, so no condition needed
        if "*" in candidates:
            return None
        # Booleans read back from JSONB as "true"/"false", not "True"/"False",
        # so they are lower-cased the same way _safe_numeric_cast does it.
        return field_accessor.in_(
            [str(v).lower() if isinstance(v, bool) else str(v) for v in candidates]
        )

    if operator in ("contains", "icontains"):
        # Both operators use ILIKE, i.e. a case-insensitive substring match.
        escaped_value = escape_ilike_pattern(str(op_value))
        return field_accessor.ilike(f"%{escaped_value}%", escape=ILIKE_ESCAPE_CHAR)

    # Defensive: reached only if COMPARISON_OPERATORS gains an operator that
    # has no branch above.
    return None


def _build_nested_metadata_conditions(
    column: Any, metadata_dict: dict[str, Any]
) -> ColumnElement[bool] | None:
    """
    Turn a dict of metadata fields into conditions on a JSONB column.

    Each entry is handled on its own: a dict value holding at least one
    comparison operator yields one comparison per operator (ANDed together),
    a wildcard value is ignored, and any other value becomes a JSONB
    containment test for ``{field: value}``.

    Args:
        column: SQLAlchemy JSONB column object
        metadata_dict: Mapping of metadata field name to value or operator dict

    Returns:
        All per-field conditions combined with AND, or None when no field
        contributed a condition
    """
    clauses: list[ColumnElement[bool]] = []

    for field_name, field_value in metadata_dict.items():
        uses_operators = isinstance(field_value, dict) and any(
            op in COMPARISON_OPERATORS
            for op in field_value  # pyright: ignore
        )

        if not uses_operators:
            # The wildcard places no restriction on this field.
            if field_value == "*":
                continue
            # Plain value: JSONB containment also matches nested objects.
            clauses.append(column.contains({field_name: field_value}))
            continue

        # One condition per operator; operators that resolve to "no
        # restriction" come back as None and are dropped.
        built = (
            _build_comparison_condition(
                column,
                field_name,
                operator,  # pyright: ignore
                op_value,
            )
            for operator, op_value in field_value.items()  # pyright: ignore
        )
        per_field: list[ColumnElement[bool]] = [
            clause for clause in built if clause is not None
        ]

        merged = _combine_conditions_with_and(per_field)
        if merged is not None:
            clauses.append(merged)

    # Every field has to match, so AND the per-field clauses together.
    return _combine_conditions_with_and(clauses)


def _combine_conditions_with_and(
    conditions: list[ColumnElement[bool]],
) -> ColumnElement[bool] | None:
    """
    Collapse a list of conditions into one condition using AND.

    Args:
        conditions: SQLAlchemy condition objects to combine

    Returns:
        None for an empty list, the sole element for a one-element list,
        otherwise a single AND over all elements
    """
    if len(conditions) > 1:
        return and_(*conditions)

    # Zero or one condition: no AND wrapper needed.
    return conditions[0] if conditions else None


def _build_comparison_conditions(
    column: Any, column_name: str, comparisons: dict[str, Any]
) -> ColumnElement[bool] | None:
    """
    Build comparison conditions for a single column.

    Every ``operator: value`` pair in ``comparisons`` becomes one condition;
    the conditions are combined with AND. Values equal to the wildcard ``"*"``
    (or ``in`` lists containing it) place no restriction and are skipped.

    Value handling:
        * On datetime columns, string values are parsed into timezone-aware
          datetimes (also inside ``in`` lists).
        * For ``gte``/``lte``/``gt``/``lt`` the value must be convertible with
          ``float()``.
        * ``ne`` uses the ``float()`` conversion when it succeeds and otherwise
          compares against the value as given, so inequality also works for
          non-numeric values such as names.

    Args:
        column: SQLAlchemy column object
        column_name: Name of the column
        comparisons: Dictionary of comparison operators and values

    Returns:
        Combined SQLAlchemy condition object or None

    Raises:
        FilterError: When an operator is unsupported, a datetime string cannot
            be parsed, an ordering operator gets a non-numeric value, or
            ``in`` gets a non-iterable (or a string/bytes) value
    """
    conditions: list[ColumnElement[bool]] = []

    # Check if this is a datetime column. `python_type` raises
    # NotImplementedError on SQLAlchemy types that do not define it, which
    # hasattr() would let through, so it is handled explicitly.
    column_type = column.type
    try:
        column_python_type = column_type.python_type
    except (AttributeError, NotImplementedError):
        column_python_type = None
    is_datetime_column = column_python_type is not None and issubclass(
        column_python_type, datetime.datetime
    )

    # Columns that hold JSONB; same set _build_field_condition treats as JSONB.
    jsonb_column_names = ("h_metadata", "configuration", "internal_metadata")

    for operator, op_value in comparisons.items():
        # Validate that the operator is supported
        if operator not in COMPARISON_OPERATORS:
            raise FilterError(f"Unsupported comparison operator: {operator}")

        # Handle wildcard - matches everything, so no condition needed
        if op_value == "*":
            continue

        condition = None

        if is_datetime_column and isinstance(op_value, str):
            # Parse into a datetime object so the value is sent as a typed
            # bind parameter rather than as raw text.
            validated_datetime = _validate_datetime_string(op_value)
            if validated_datetime is None:
                raise FilterError(f"Invalid datetime value: {op_value}")
            casted_value = validated_datetime
        elif operator in NUMERIC_OPERATORS:
            try:
                casted_value = float(op_value)
            except (TypeError, ValueError):
                # Ordering comparisons need a number. Inequality is meaningful
                # for any value, so "ne" falls back to the value as given.
                if operator != "ne":
                    raise FilterError(
                        f"Invalid numeric value: {op_value}. Expected a number, got {type(op_value).__name__}"
                    ) from None
                casted_value = op_value
        else:
            casted_value = op_value

        if operator == "gte":
            condition = column >= casted_value
        elif operator == "lte":
            condition = column <= casted_value
        elif operator == "gt":
            condition = column > casted_value
        elif operator == "lt":
            condition = column < casted_value
        elif operator == "ne":
            condition = column != casted_value
        elif operator == "in":
            if not hasattr(op_value, "__iter__") or isinstance(op_value, str | bytes):
                raise FilterError(
                    f"Invalid value for 'in' operator: {op_value}. Expected an iterable (list, tuple, set), got {type(op_value).__name__}"
                )
            # Materialize once so one-shot iterables survive the wildcard probe.
            candidates = list(op_value)
            # Handle wildcard in iterable - if present, matches everything, so no condition needed
            if "*" in candidates:
                continue
            if is_datetime_column:
                # Validate and cast each datetime string value
                casted_values: list[Any] = []
                for val in candidates:
                    if isinstance(val, str):
                        validated_datetime = _validate_datetime_string(val)
                        if validated_datetime is None:
                            raise FilterError(
                                f"Invalid datetime value in list: {val}"
                            )
                        casted_values.append(validated_datetime)
                    else:
                        casted_values.append(val)
                candidates = casted_values
            # An empty list yields an IN over the empty set (matches nothing)
            # for every column type, datetime columns included.
            condition = column.in_(candidates)
        elif operator == "contains":
            if column_name in jsonb_column_names:
                # For JSONB columns, use JSONB contains
                condition = column.contains(op_value)
            else:
                # For text columns, use ILIKE with escaped pattern
                escaped_value = escape_ilike_pattern(str(op_value))
                condition = column.ilike(f"%{escaped_value}%", escape=ILIKE_ESCAPE_CHAR)
        elif operator == "icontains":
            # Case-insensitive contains for text columns with escaped pattern
            escaped_value = escape_ilike_pattern(str(op_value))
            condition = column.ilike(f"%{escaped_value}%", escape=ILIKE_ESCAPE_CHAR)

        if condition is not None:
            conditions.append(condition)

    # Combine all conditions for this field with AND
    return _combine_conditions_with_and(conditions)


def _validate_datetime_string(value: str) -> datetime.datetime | None:
    """
    Parse a user-supplied datetime string, returning None when it is invalid.

    Parsing instead of passing the raw text on means only a real datetime
    object ever reaches the query. The shared parse_datetime_iso utility is
    tried first; if it rejects the input with ValueError, a fixed list of
    naive formats is tried and the result is stamped as UTC.

    Args:
        value: Candidate datetime string; surrounding whitespace is ignored

    Returns:
        The parsed datetime, or None if no format matched
    """
    candidate = value.strip()

    # Preferred path: the project's ISO parser.
    try:
        return parse_datetime_iso(candidate)
    except ValueError:
        pass

    # Fallback: naive formats, all interpreted as UTC for compatibility.
    fallback_formats = (
        "%Y-%m-%dT%H:%M:%S",  # 2024-01-01T12:00:00
        "%Y-%m-%dT%H:%M:%S.%f",  # 2024-01-01T12:00:00.123456
        "%Y-%m-%d %H:%M:%S",  # 2024-01-01 12:00:00
        "%Y-%m-%d %H:%M:%S.%f",  # 2024-01-01 12:00:00.123456
        "%Y-%m-%d",  # 2024-01-01 (start of day)
    )

    for pattern in fallback_formats:
        try:
            naive = datetime.datetime.strptime(candidate, pattern)
        except ValueError:
            continue
        return naive.replace(tzinfo=datetime.timezone.utc)

    # Nothing matched; the caller decides how to report the bad value.
    return None
