import calendar
from datetime import datetime, date, timedelta
from decimal import Decimal
from typing import Optional

import mysql.connector
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity

from db.db import get_db_connection
from config.auth import role_required

# Blueprint that every channeling-session route in this module registers on.
session_bp = Blueprint(name='session', import_name=__name__)


# ─────────────────────────────────────────────────────────────
# 🛠️  HELPERS
# ─────────────────────────────────────────────────────────────

def _td_to_hhmm(val) -> str:
    """
    Convert a MySQL TIME column (timedelta or string) → "HH:MM".
    Handles: timedelta, "HH:MM:SS", "HH:MM", None.
    """
    if val is None:
        return None
    if isinstance(val, timedelta):
        total = int(val.total_seconds())
        h, rem = divmod(total, 3600)
        m, _   = divmod(rem, 60)
        return f"{h:02d}:{m:02d}"
    if isinstance(val, str):
        parts = val.split(':')
        if len(parts) >= 2:
            return f"{int(parts[0]):02d}:{parts[1][:2]}"
        return val
    return str(val)


def _serialize_row(row: dict) -> dict:
    """
    Convert datetime / date / timedelta / Decimal → JSON-safe types.
    TIME columns (timedelta) → "HH:MM" via _td_to_hhmm().
    """
    for key, val in list(row.items()):
        if isinstance(val, datetime):
            row[key] = val.strftime('%Y-%m-%d %H:%M:%S')
        elif isinstance(val, date):
            row[key] = val.isoformat()
        elif isinstance(val, timedelta):
            row[key] = _td_to_hhmm(val)
        elif hasattr(val, '__float__'):  # Decimal
            row[key] = float(val)
    return row


def _compute_session_status(s: dict) -> str:
    """
    Derive display status from DB status + booked/max ratio.

    DB values  : 'active', 'cancelled', 'completed'
    Display    : 'Open', 'Closed', 'Full'
    """
    db_status = (s.get('status') or 'active').lower()

    if db_status in ('cancelled', 'completed'):
        return 'Closed'

    booked  = int(s.get('booked_count') or 0)
    max_pat = s.get('max_patients')

    if max_pat is not None:
        try:
            max_pat = int(max_pat)
        except (ValueError, TypeError):
            max_pat = None

    if max_pat and booked >= max_pat:
        return 'Full'

    return 'Open'


# ─────────────────────────────────────────────────────────────
# GET /all_sessions
# ─────────────────────────────────────────────────────────────
@session_bp.route('/all_sessions', methods=['GET'])
@jwt_required()
def get_all_sessions():
    """
    List channeling sessions, latest session_date first.

    Optional query parameters:
        store_id  - only sessions of this branch (int).
        doctor_id - only sessions of this doctor (int).
        date      - exact session_date to match, forwarded to SQL unchanged.
        status    - keep only sessions whose *derived* status
                    ('Open' / 'Closed' / 'Full') equals this value; applied
                    in Python after _compute_session_status().

    Each row carries doctor/store details and booked_count (appointments whose
    status is not 'cancelled'); session_start/session_end are renamed to
    start_time/end_time.
    """
    args = request.args
    store_id = args.get('store_id', type=int)
    doctor_id = args.get('doctor_id', type=int)
    date_str = args.get('date', '').strip()
    status_f = args.get('status', '').strip()

    # Candidate WHERE fragments in a fixed order; falsy values are dropped.
    candidate_filters = (
        (" AND cs.store_id = %s", store_id),
        (" AND cs.doctor_id = %s", doctor_id),
        (" AND cs.session_date = %s", date_str),
    )
    applied = [(fragment, value) for fragment, value in candidate_filters if value]
    where_extra = "".join(fragment for fragment, _ in applied)
    params = [value for _, value in applied]

    conn = cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        base_query = """
            SELECT
                cs.id,
                cs.doctor_id,
                cs.store_id,
                cs.session_date,
                cs.session_start,
                cs.session_end,
                cs.max_patients,
                cs.session_no,
                cs.status,
                cs.created_at,
                d.name           AS doctor_name,
                d.reg_no,
                d.specialization,
                d.channeling_fee,
                s.store_name,
                COUNT(a.id)      AS booked_count
            FROM channeling_sessions cs
            LEFT JOIN doctors d ON cs.doctor_id = d.id
            LEFT JOIN stores  s ON cs.store_id  = s.id
            LEFT JOIN appointments a
                   ON a.session_id = cs.id
                  AND a.status != 'cancelled'
            WHERE 1=1
        """
        cursor.execute(
            base_query + where_extra
            + " GROUP BY cs.id ORDER BY cs.session_date DESC, cs.session_start ASC",
            params,
        )

        sessions = []
        for record in cursor.fetchall():
            _serialize_row(record)
            for old_key, new_key in (('session_start', 'start_time'), ('session_end', 'end_time')):
                record[new_key] = record.pop(old_key, None)
            record['status'] = _compute_session_status(record)
            if not status_f or record['status'] == status_f:
                sessions.append(record)

        return jsonify({"success": True, "sessions": sessions}), 200

    except mysql.connector.Error as err:
        print(f"[DB ERROR] get_all_sessions: {err}")
        return jsonify({"success": False, "error": "Database error while fetching sessions."}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ─────────────────────────────────────────────────────────────
# GET /get_session/<id>
# ─────────────────────────────────────────────────────────────
@session_bp.route('/get_session/<int:session_id>', methods=['GET'])
@jwt_required()
def get_session(session_id):
    """
    Fetch one channeling session by id.

    Responds 200 with the session (doctor/store details, booked_count of
    non-cancelled appointments, start_time/end_time and derived status),
    404 when no row matches, or 500 on a database error.
    """
    conn = cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            """
            SELECT
                cs.id, cs.doctor_id, cs.store_id, cs.session_date,
                cs.session_start, cs.session_end, cs.max_patients,
                cs.session_no, cs.status, cs.created_at,
                d.name AS doctor_name, d.reg_no, d.specialization,
                d.channeling_fee, s.store_name,
                COUNT(a.id) AS booked_count
            FROM channeling_sessions cs
            LEFT JOIN doctors d ON cs.doctor_id = d.id
            LEFT JOIN stores  s ON cs.store_id  = s.id
            LEFT JOIN appointments a
                   ON a.session_id = cs.id AND a.status != 'cancelled'
            WHERE cs.id = %s
            GROUP BY cs.id
            """,
            (session_id,)
        )
        record = cursor.fetchone()
        if not record:
            return jsonify({"success": False, "error": "Session not found."}), 404

        _serialize_row(record)
        for old_key, new_key in (('session_start', 'start_time'), ('session_end', 'end_time')):
            record[new_key] = record.pop(old_key, None)
        record['status'] = _compute_session_status(record)

        payload = {"success": True, "session": record}
        return jsonify(payload), 200

    except mysql.connector.Error as err:
        print(f"[DB ERROR] get_session: {err}")
        return jsonify({"success": False, "error": "Database error."}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ─────────────────────────────────────────────────────────────
# POST /add_session
# ─────────────────────────────────────────────────────────────
@session_bp.route('/add_session', methods=['POST'])
@jwt_required()
def add_session():
    """
    Create a single channeling session for a doctor.

    JSON body:
        doctor_id    (int, required) - an Active doctor with a store_id; the
                                       session is created in that doctor's
                                       branch.
        session_date (str, required) - "YYYY-MM-DD", today or later.
        session_no   (int, required) - positive integer; must not already
                                       exist for this doctor/branch/date.
        start_time   (str, required) - "HH:MM" or "HH:MM:SS".
        end_time     (str, required) - same format, strictly after start_time.
        max_patients (int, optional) - positive integer; null or "" = no cap.
        status       (str, optional) - display status 'Open' (default),
                                       'Full' -> 'active', 'Closed' ->
                                       'cancelled' (case-insensitive).

    Responses: 201 on success; 400/404/409/422 for bad input, unknown
    doctor or duplicates; 500 on other database errors.
    """
    def _parse_clock(value):
        """Parse "HH:MM" or "HH:MM:SS" into a datetime.time; None if neither matches."""
        for fmt in ('%H:%M', '%H:%M:%S'):
            try:
                return datetime.strptime(value, fmt).time()
            except ValueError:
                continue
        return None

    # silent=True: a missing / non-JSON body yields None and gets our 400
    # below instead of Flask's generic error page.
    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        return jsonify({"success": False, "error": "No data provided."}), 400

    doctor_id      = data.get('doctor_id')
    session_date   = data.get('session_date')
    session_no     = data.get('session_no')
    # str() guards against non-string JSON values (e.g. numbers) before .strip().
    start_time     = str(data.get('start_time') or '').strip()
    end_time       = str(data.get('end_time')   or '').strip()
    max_patients   = data.get('max_patients')
    status_in      = str(data.get('status') or 'Open').strip()
    # NOTE: 'notes' is not a column in channeling_sessions — intentionally not read.

    # ── Required field validation ──────────────────────────────
    if not doctor_id:
        return jsonify({"success": False, "error": "doctor_id is required."}), 422
    if not session_date:
        return jsonify({"success": False, "error": "session_date is required."}), 422
    if not session_no:
        return jsonify({"success": False, "error": "session_no is required."}), 422
    if not start_time:
        return jsonify({"success": False, "error": "start_time is required."}), 422
    if not end_time:
        return jsonify({"success": False, "error": "end_time is required."}), 422

    # ── Type / format validation ───────────────────────────────
    try:
        doctor_id = int(doctor_id)
    except (ValueError, TypeError):
        return jsonify({"success": False, "error": "doctor_id must be an integer."}), 422

    try:
        session_no = int(session_no)
        if session_no < 1:
            raise ValueError
    except (ValueError, TypeError):
        return jsonify({"success": False, "error": "session_no must be a positive integer."}), 422

    # Compare parsed times, not raw strings: "9:00" >= "10:00" is True
    # lexicographically and would wrongly reject a valid window.
    start_t = _parse_clock(start_time)
    end_t   = _parse_clock(end_time)
    if start_t is None or end_t is None:
        return jsonify({"success": False,
                        "error": "start_time and end_time must use HH:MM or HH:MM:SS format."}), 422
    if start_t >= end_t:
        return jsonify({"success": False, "error": "end_time must be after start_time."}), 422

    # ── Past date validation ───────────────────────────────────
    try:
        session_date_obj = datetime.strptime(str(session_date).strip(), '%Y-%m-%d').date()
    except ValueError:
        return jsonify({"success": False, "error": "Invalid session_date format. Use YYYY-MM-DD."}), 422

    if session_date_obj < date.today():
        return jsonify({"success": False, "error": "Session date cannot be in the past."}), 422
    session_date = session_date_obj.isoformat()   # normalized form used for SQL and messages

    # ── Optional fields ────────────────────────────────────────
    # An empty form field ("") means "no cap", same as null.
    if max_patients == '':
        max_patients = None
    if max_patients is not None:
        try:
            max_patients = int(max_patients)
            if max_patients < 1:
                raise ValueError
        except (ValueError, TypeError):
            return jsonify({"success": False, "error": "max_patients must be a positive integer."}), 422

    # Display status -> DB enum value. Unknown values are rejected instead of
    # silently creating a cancelled session.
    status_map = {'open': 'active', 'full': 'active', 'closed': 'cancelled'}
    db_status = status_map.get(status_in.lower())
    if db_status is None:
        return jsonify({"success": False, "error": "status must be one of: Open, Full, Closed."}), 422

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            "SELECT id, store_id FROM doctors WHERE id = %s AND status = 'Active'",
            (doctor_id,)
        )
        doctor = cursor.fetchone()
        if not doctor:
            return jsonify({"success": False, "error": "Doctor not found or inactive."}), 404

        store_id = doctor['store_id']
        if not store_id:
            return jsonify({"success": False, "error": "Doctor is not assigned to any branch."}), 400

        cursor.execute(
            """
            SELECT id FROM channeling_sessions
            WHERE doctor_id = %s AND store_id = %s
              AND session_date = %s AND session_no = %s
            LIMIT 1
            """,
            (doctor_id, store_id, session_date, session_no)
        )
        if cursor.fetchone():
            return jsonify({"success": False,
                            "error": f"Session {session_no} for this doctor on {session_date} already exists."}), 409

        cursor.execute(
            """
            INSERT INTO channeling_sessions
                (doctor_id, store_id, session_date, session_no,
                 session_start, session_end, max_patients, status, created_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
            """,
            (doctor_id, store_id, session_date, session_no,
             start_t.strftime('%H:%M:%S'), end_t.strftime('%H:%M:%S'),
             max_patients, db_status)
        )
        conn.commit()
        new_id = cursor.lastrowid

        return jsonify({
            "success":    True,
            "message":    f"Session created successfully for {session_date}.",
            "session_id": new_id,
        }), 201

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        # 1062 = ER_DUP_ENTRY: a concurrent request inserted the same session
        # between our duplicate check and the INSERT (needs a unique key).
        if getattr(e, 'errno', None) == 1062:
            return jsonify({"success": False,
                            "error": f"Session {session_no} for this doctor on {session_date} already exists."}), 409
        print(f"[DB ERROR] add_session: {e}")
        return jsonify({"success": False, "error": "Database error while creating session."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ─────────────────────────────────────────────────────────────
# PUT /update_session/<id>
# FIX: Removed channeling_fee — it does NOT exist in channeling_sessions table.
#      channeling_fee lives on the doctors table, not sessions.
# ─────────────────────────────────────────────────────────────
@session_bp.route('/update_session/<int:session_id>', methods=['PUT'])
@jwt_required()
def update_session(session_id):
    """
    Partially update a channeling session; only keys present in the body change.

    JSON body (all optional, at least one valid field required):
        session_date - "YYYY-MM-DD".
        start_time / end_time - "HH:MM" or "HH:MM:SS". The resulting window
            (supplied value, or the stored one for the side not supplied)
            must have end after start.
        max_patients - positive integer, or null / "" to clear the cap.
        status - display value 'Open' / 'Full' -> 'active',
            'Closed' -> 'cancelled' (case-insensitive); others are rejected.

    channeling_fee lives on the doctors table and notes has no column in
    channeling_sessions, so both are ignored.

    Responses: 200 on success, 400 for an empty body or no updatable field,
    404 for an unknown id, 409 on a duplicate session, 422 on invalid
    values, 500 on other database errors.
    """
    def _parse_clock(value):
        """Parse "HH:MM" or "HH:MM:SS" into a datetime.time; None if neither matches."""
        text = str(value).strip()
        for fmt in ('%H:%M', '%H:%M:%S'):
            try:
                return datetime.strptime(text, fmt).time()
            except ValueError:
                continue
        return None

    def _seconds_of_day(value):
        """Seconds since midnight for a time, a TIME timedelta or a clock string; None if unknown."""
        if value is None:
            return None
        if isinstance(value, timedelta):
            return int(value.total_seconds())
        parsed = value if hasattr(value, 'hour') else _parse_clock(value)
        if parsed is None:
            return None
        return parsed.hour * 3600 + parsed.minute * 60 + parsed.second

    data = request.get_json(silent=True)
    if not data or not isinstance(data, dict):
        return jsonify({"success": False, "error": "No data provided."}), 400

    # Status mapping: frontend display value (case-insensitive) → DB enum value
    status_map = {'open': 'active', 'full': 'active', 'closed': 'cancelled'}

    set_clauses = []
    params      = []

    # ── Validate and collect fields (no DB access needed) ─────
    if data.get('session_date'):
        try:
            new_date = datetime.strptime(str(data['session_date']).strip(), '%Y-%m-%d').date()
        except ValueError:
            return jsonify({"success": False, "error": "Invalid session_date format. Use YYYY-MM-DD."}), 422
        set_clauses.append("session_date = %s")
        params.append(new_date.isoformat())

    new_start = new_end = None
    if data.get('start_time'):
        new_start = _parse_clock(data['start_time'])
        if new_start is None:
            return jsonify({"success": False, "error": "start_time must be in HH:MM or HH:MM:SS format."}), 422
        set_clauses.append("session_start = %s")
        params.append(new_start.strftime('%H:%M:%S'))

    if data.get('end_time'):
        new_end = _parse_clock(data['end_time'])
        if new_end is None:
            return jsonify({"success": False, "error": "end_time must be in HH:MM or HH:MM:SS format."}), 422
        set_clauses.append("session_end = %s")
        params.append(new_end.strftime('%H:%M:%S'))

    # Compare parsed times, not strings ("9:00" >= "10:00" is True as text).
    if new_start is not None and new_end is not None and new_start >= new_end:
        return jsonify({"success": False, "error": "end_time must be after start_time."}), 422

    if 'max_patients' in data:
        mp = data['max_patients']
        if mp is None or mp == '':
            set_clauses.append("max_patients = NULL")
        else:
            try:
                mp = int(mp)
                if mp < 1: raise ValueError
                set_clauses.append("max_patients = %s")
                params.append(mp)
            except (ValueError, TypeError):
                return jsonify({"success": False, "error": "max_patients must be a positive integer."}), 422

    if data.get('status'):
        # Reject unknown values: defaulting them to 'active' would silently
        # re-open a session the caller tried to cancel (e.g. 'Cancelled').
        db_status = status_map.get(str(data['status']).strip().lower())
        if db_status is None:
            return jsonify({"success": False, "error": "status must be one of: Open, Full, Closed."}), 422
        set_clauses.append("status = %s")
        params.append(db_status)

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            "SELECT id, session_start, session_end FROM channeling_sessions WHERE id = %s",
            (session_id,)
        )
        current = cursor.fetchone()
        if not current:
            return jsonify({"success": False, "error": "Session not found."}), 404

        if not set_clauses:
            return jsonify({"success": False, "error": "No valid fields to update."}), 400

        # Only one end of the window supplied: check it against the stored other end.
        if (new_start is None) != (new_end is None):
            start_s = _seconds_of_day(new_start if new_start is not None else current.get('session_start'))
            end_s   = _seconds_of_day(new_end if new_end is not None else current.get('session_end'))
            if start_s is not None and end_s is not None and start_s >= end_s:
                return jsonify({"success": False, "error": "end_time must be after start_time."}), 422

        params.append(session_id)
        # set_clauses only ever holds the fixed literals built above, so the
        # f-string cannot inject caller data; values go through params.
        cursor.execute(
            f"UPDATE channeling_sessions SET {', '.join(set_clauses)} WHERE id = %s",
            params
        )
        conn.commit()

        return jsonify({"success": True, "message": "Session updated successfully.", "session_id": session_id}), 200

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        # 1062 = ER_DUP_ENTRY, e.g. moving the session onto a date that
        # already has this doctor's session_no (requires a unique key).
        if getattr(e, 'errno', None) == 1062:
            return jsonify({"success": False,
                            "error": "Another session for this doctor already uses that date and session number."}), 409
        print(f"[DB ERROR] update_session: {e}")
        return jsonify({"success": False, "error": "Database error while updating session."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()


# ─────────────────────────────────────────────────────────────
# DELETE /delete_session/<id>
# ─────────────────────────────────────────────────────────────
@session_bp.route('/delete_session/<int:session_id>', methods=['DELETE'])
@jwt_required()
def delete_session(session_id):
    """
    Permanently delete a channeling session.

    Responds 404 when the id is unknown. Responds 409 while any appointment
    whose status is not 'cancelled' still references the session; the
    message tells the caller to cancel those or close the session instead.
    Any database error rolls back and yields 500.
    """
    conn = cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("SELECT id FROM channeling_sessions WHERE id = %s", (session_id,))
        found = cursor.fetchone()
        if not found:
            return jsonify({"success": False, "error": "Session not found."}), 404

        cursor.execute(
            "SELECT COUNT(*) AS cnt FROM appointments WHERE session_id = %s AND status != 'cancelled'",
            (session_id,)
        )
        tally = cursor.fetchone()
        if tally and int(tally['cnt']) > 0:
            conflict_msg = (
                f"Cannot delete this session — {tally['cnt']} active appointment(s) are linked. "
                "Cancel those appointments first, or mark the session as Closed instead."
            )
            return jsonify({"success": False, "error": conflict_msg}), 409

        cursor.execute("DELETE FROM channeling_sessions WHERE id = %s", (session_id,))
        conn.commit()

        result = {"success": True, "message": "Session deleted successfully.", "session_id": session_id}
        return jsonify(result), 200

    except mysql.connector.Error as err:
        if conn:
            conn.rollback()
        print(f"[DB ERROR] delete_session: {err}")
        return jsonify({"success": False, "error": "Database error while deleting session."}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ─────────────────────────────────────────────────────────────
# GET /preview_auto_sessions
# ─────────────────────────────────────────────────────────────
@session_bp.route('/preview_auto_sessions', methods=['GET'])
@jwt_required()
def preview_auto_sessions():
    """
    Dry run of auto-creation: list the sessions a month's schedules would yield.

    Query parameters:
        month     - "YYYY-MM" (defaults to the current month).
        doctor_id - optional int filter on the doctor.
        store_id  - optional int filter on the doctor's branch.

    Produces one entry per (doctor_schedules row x matching weekday of the
    month) for Active doctors with a store. Each entry is flagged with
    ``exists`` (a session with the same doctor, date and session_no is
    already stored) and ``is_past`` (date before today). new_count counts
    entries that are neither; skip_count counts the rest. Only SELECTs run,
    so nothing is written to the database.
    """
    month_str = request.args.get('month', '').strip()
    doctor_id = request.args.get('doctor_id', type=int)
    store_id = request.args.get('store_id', type=int)

    today = date.today()
    target_month = today.replace(day=1)
    if month_str:
        try:
            year_part, month_part = map(int, month_str.split('-'))
            target_month = date(year_part, month_part, 1)
        except (ValueError, AttributeError):
            return jsonify({"success": False, "error": "Invalid month format. Use YYYY-MM."}), 422

    year, mon = target_month.year, target_month.month
    weekday_labels = ('Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday')
    # Weekday name -> that month's dates falling on it, in ascending order.
    dates_by_weekday = {}
    for day_offset in range(calendar.monthrange(year, mon)[1]):
        cal_date = target_month + timedelta(days=day_offset)
        dates_by_weekday.setdefault(weekday_labels[cal_date.weekday()], []).append(cal_date)
    month_label = target_month.strftime('%B %Y')

    conn = cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        filter_parts = []
        doc_params = []
        if doctor_id:
            filter_parts.append(" AND d.id = %s")
            doc_params.append(doctor_id)
        if store_id:
            filter_parts.append(" AND d.store_id = %s")
            doc_params.append(store_id)
        doc_filter = "".join(filter_parts)

        cursor.execute(
            f"""
            SELECT
                d.id           AS doctor_id,
                d.name         AS doctor_name,
                d.store_id,
                d.channeling_fee,
                ds.day,
                ds.session_no,
                ds.start_time,
                ds.end_time,
                ds.max_patients
            FROM doctors d
            JOIN doctor_schedules ds ON ds.doctor_id = d.id
            WHERE d.status = 'Active'
              AND d.store_id IS NOT NULL
              {doc_filter}
            ORDER BY d.id, ds.day, ds.session_no
            """,
            doc_params
        )
        schedules = cursor.fetchall()

        if not schedules:
            return jsonify({
                "success":    True,
                "month":      month_label,
                "preview":    [],
                "new_count":  0,
                "skip_count": 0,
            }), 200

        # ── Sessions already stored for this month ─────────────
        cursor.execute(
            """
            SELECT doctor_id, session_date, session_no
            FROM channeling_sessions
            WHERE YEAR(session_date) = %s AND MONTH(session_date) = %s
            """,
            (year, mon)
        )
        existing_keys = set()
        for stored in cursor.fetchall():
            stored_day = stored['session_date']
            day_text = stored_day.isoformat() if hasattr(stored_day, 'isoformat') else str(stored_day)[:10]
            existing_keys.add((int(stored['doctor_id']), day_text, int(stored['session_no'])))

        # ── One preview entry per schedule row x matching date ─
        preview = []
        for sch in schedules:
            for cal_date in dates_by_weekday.get(sch['day'], []):
                identity = (int(sch['doctor_id']), cal_date.isoformat(), int(sch['session_no']))
                preview.append({
                    "doctor_id":      int(sch['doctor_id']),
                    "doctor_name":    sch['doctor_name'],
                    "session_date":   cal_date.isoformat(),
                    "day":            sch['day'],
                    "session_no":     int(sch['session_no']),
                    "start_time":     _td_to_hhmm(sch['start_time']),
                    "end_time":       _td_to_hhmm(sch['end_time']),
                    "max_patients":   sch['max_patients'],
                    "channeling_fee": float(sch['channeling_fee'] or 0),
                    "exists":         identity in existing_keys,
                    "is_past":        cal_date < today,
                })

        skip_count = sum(1 for entry in preview if entry["exists"] or entry["is_past"])
        new_count = len(preview) - skip_count

        return jsonify({
            "success":    True,
            "month":      month_label,
            "preview":    preview,
            "new_count":  new_count,
            "skip_count": skip_count,
        }), 200

    except mysql.connector.Error as err:
        print(f"[DB ERROR] preview_auto_sessions: {err}")
        return jsonify({"success": False, "error": "Database error during preview."}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ─────────────────────────────────────────────────────────────
# POST /auto_create_sessions
# ─────────────────────────────────────────────────────────────
@session_bp.route('/auto_create_sessions', methods=['POST'])
@jwt_required()
def auto_create_sessions():
    """
    Bulk-create a month's channeling sessions from doctor_schedules.

    JSON body (all optional; an empty or missing body means "current month,
    all doctors"):
        month     - "YYYY-MM"; defaults to the current month.
        doctor_id - integer; restrict to one doctor.
        store_id  - integer; restrict to doctors of one branch.

    For every Active doctor with a store and a schedule row, one 'active'
    session is inserted per matching weekday of the month. Past dates are
    skipped. INSERT IGNORE leaves rows that already exist untouched; this
    assumes a unique key on the session identity (TODO confirm against the
    schema). Per-row DB errors are collected (first 10 returned) without
    aborting the batch. All successful inserts are committed together.

    Responses: 201 with created/skipped counts, 400 for a non-object body,
    404 when no schedules match, 422 for invalid month/filter values, 500 on
    database errors.
    """
    # silent=True: a body-less POST (or one without a JSON content type) is
    # the normal "use defaults" call; without it Flask rejects the request
    # (415/400) before the `or {}` fallback can apply.
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"success": False, "error": "Request body must be a JSON object."}), 400

    month_str = str(data.get('month') or '').strip()

    today = date.today()
    if month_str:
        try:
            year, mon = [int(x) for x in month_str.split('-')]
            target_month = date(year, mon, 1)
        except (ValueError, AttributeError):
            return jsonify({"success": False, "error": "Invalid month format. Use YYYY-MM."}), 422
    else:
        target_month = date(today.year, today.month, 1)

    # Optional integer filters: reject junk with a 422 up front instead of
    # letting int() raise an unhandled ValueError inside the DB block.
    doc_filter = ""
    doc_params = []
    for column, field in (('d.id', 'doctor_id'), ('d.store_id', 'store_id')):
        raw = data.get(field)
        if not raw:
            continue
        try:
            doc_params.append(int(raw))
        except (ValueError, TypeError):
            return jsonify({"success": False, "error": f"{field} must be an integer."}), 422
        doc_filter += f" AND {column} = %s"

    year = target_month.year
    mon  = target_month.month
    days_in_month = calendar.monthrange(year, mon)[1]
    day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
    # Weekday name -> dates of the target month on that weekday (ascending),
    # so each schedule row only visits its own dates.
    dates_by_day = {}
    for day_no in range(1, days_in_month + 1):
        cal_date = date(year, mon, day_no)
        dates_by_day.setdefault(day_names[cal_date.weekday()], []).append(cal_date)

    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute(
            f"""
            SELECT
                d.id           AS doctor_id,
                d.store_id,
                d.channeling_fee,
                ds.day,
                ds.session_no,
                ds.start_time,
                ds.end_time,
                ds.max_patients
            FROM doctors d
            JOIN doctor_schedules ds ON ds.doctor_id = d.id
            WHERE d.status = 'Active'
              AND d.store_id IS NOT NULL
              {doc_filter}
            ORDER BY d.id, ds.day, ds.session_no
            """,
            doc_params
        )
        schedules = cursor.fetchall()

        if not schedules:
            return jsonify({
                "success": False,
                "error": (
                    "No active doctor schedules found. "
                    "Check that doctors are Active, have a store assigned, "
                    "and have schedules configured in doctor_schedules."
                )
            }), 404

        created = 0
        skipped = 0
        errors  = []

        for sch in schedules:
            for cal_date in dates_by_day.get(sch['day'], []):
                # Skip past dates
                if cal_date < today:
                    skipped += 1
                    continue

                start_str = _td_to_hhmm(sch['start_time'])
                end_str   = _td_to_hhmm(sch['end_time'])
                max_pat   = sch['max_patients']

                if not start_str or not end_str:
                    errors.append(
                        f"{cal_date} doc#{sch['doctor_id']} sess#{sch['session_no']}: "
                        f"Missing time values (start={start_str}, end={end_str})"
                    )
                    skipped += 1
                    continue

                try:
                    cursor.execute(
                        """
                        INSERT IGNORE INTO channeling_sessions
                            (doctor_id, store_id, session_date, session_no,
                             session_start, session_end, max_patients,
                             status, created_at)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, 'active', NOW())
                        """,
                        (
                            int(sch['doctor_id']),
                            int(sch['store_id']),
                            cal_date.isoformat(),
                            int(sch['session_no']),
                            start_str,
                            end_str,
                            max_pat,
                        )
                    )
                    # rowcount 0 means INSERT IGNORE dropped an existing row.
                    if cursor.rowcount == 1:
                        created += 1
                    else:
                        skipped += 1

                except mysql.connector.Error as row_err:
                    err_msg = (
                        f"{cal_date} doc#{sch['doctor_id']} "
                        f"sess#{sch['session_no']}: {row_err}"
                    )
                    print(f"[AUTO-CREATE ROW ERROR] {err_msg}")
                    errors.append(err_msg)
                    skipped += 1

        conn.commit()

        month_label = target_month.strftime('%B %Y')
        print(f"[AUTO-CREATE] {month_label}: created={created} skipped={skipped} errors={len(errors)}")

        return jsonify({
            "success": True,
            "message": (
                f"Auto-create complete for {month_label}. "
                f"{created} session(s) created, {skipped} skipped (past dates or already existed)."
            ),
            "created": created,
            "skipped": skipped,
            "errors":  errors[:10],
        }), 201

    except mysql.connector.Error as e:
        if conn: conn.rollback()
        print(f"[DB ERROR] auto_create_sessions: {e}")
        return jsonify({"success": False, "error": "Database error during auto-create."}), 500
    finally:
        if cursor: cursor.close()
        if conn:   conn.close()