from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from db.db import get_db_connection
from config.auth import role_required
from datetime import datetime, timedelta
import json
from decimal import Decimal
from datetime import date, datetime

# Flask blueprint grouping the sale/invoice endpoints registered below via @sale_bp.route.
sale_bp = Blueprint('sale', __name__)

# ============================================
# WAREHOUSE-SPECIFIC STOCK MANAGEMENT HELPERS
# ============================================

def get_available_stock_by_warehouse(cursor, product_id, variation_id, warehouse_id, price=None):
    """
    Return the quantity of a product that can still be sold from one warehouse.

    The figure is the sum, over every batch with stock in the warehouse, of the
    smaller of the batch's remaining_quantity and its warehouse_stock quantity.
    ✅ PRICE-AWARE: optionally restricted to batches carried at a given pb.price.

    Args:
        cursor: Dictionary cursor on an open DB connection.
        product_id: Product ID (only used when variation_id is falsy).
        variation_id: Variation ID; a falsy value selects batches whose
            variation_id is NULL or 0.
        warehouse_id: Warehouse to inspect; a falsy value short-circuits to 0.0.
        price: When not None, only batches whose pb.price equals this value count.

    Returns:
        float: Available quantity (0.0 when the query returns no row).
    """
    if not warehouse_id:
        return 0.0

    select_clause = """
        SELECT IFNULL(SUM(
            LEAST(pb.remaining_quantity, ws.quantity)
        ), 0) as available_stock
        FROM product_batches pb
        INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
    """

    if variation_id:
        conditions = ["pb.variation_id = %s", "ws.warehouse_id = %s"]
        bind_values = [variation_id, warehouse_id]
    else:
        conditions = [
            "pb.product_id = %s",
            "(pb.variation_id IS NULL OR pb.variation_id = 0)",
            "ws.warehouse_id = %s",
        ]
        bind_values = [product_id, warehouse_id]

    # Only batches / warehouse rows that still hold stock contribute.
    conditions += ["pb.remaining_quantity > 0", "ws.quantity > 0"]

    # our_price column was dropped; price filtering is done on pb.price.
    if price is not None:
        conditions.append("pb.price = %s")
        bind_values.append(float(price))

    sql = select_clause + " WHERE " + "\n            AND ".join(conditions)
    cursor.execute(sql, tuple(bind_values))

    row = cursor.fetchone()
    if not row:
        return 0.0
    return float(row['available_stock'])


def deduct_stock_from_batches_fifo_warehouse(cursor, product_id, variation_id, quantity, warehouse_id, store_id, sale_product_id, price=None):
    """
    Deduct stock from product_batches in FIFO (batch_id ascending) order for one warehouse.

    For every batch consumed this also decrements the matching warehouse_stock
    row and inserts a sale_product_items row linking the batch to the sale line.

    ✅ PRICE-AWARE: when ``price`` is given, only batches whose pb.price equals it are used
    ✅ INCLUDES ROW LOCKING: candidate rows are selected with FOR UPDATE

    Args:
        cursor: Dictionary cursor inside an open transaction.
        product_id: Product ID (filters batches only when variation_id is falsy).
        variation_id: Variation ID; falsy selects batches whose variation_id is NULL or 0.
        quantity: Quantity to deduct (converted with float()).
        warehouse_id: Warehouse to deduct from; required.
        store_id: Accepted for interface compatibility; not used by this function.
        sale_product_id: sale_products ID written into each sale_product_items row.
        price: Optional selling price used to filter batches by pb.price.

    Returns:
        list[dict]: One entry per batch touched, with batch_id, grn_id,
        quantity, cost, price and expiration_date.

    Raises:
        Exception: If warehouse_id is missing, no matching batch exists, a row
        update does not affect exactly one row, or the matching batches hold
        less than ``quantity`` (beyond a 0.001 tolerance).
    """
    if not warehouse_id:
        raise Exception(f"Warehouse ID is required for product {product_id}")

    # Quantities are tracked as floats, so subtraction can leave a residue such
    # as 5.5e-17 (e.g. (0.1 + 0.2) - 0.3). Anything at or below this threshold
    # counts as fully deducted, so no further batch is touched (and no
    # sale_product_items row is written) for a phantom amount.
    float_residue = 1e-9

    remaining_qty = float(quantity)
    batches_used  = []

    # Build query with price filter and row locking
    if variation_id:
        query = """
            SELECT pb.batch_id, pb.grn_id, pb.remaining_quantity,
                   pb.cost, pb.price, pb.expiration_date,
                   ws.id as warehouse_stock_id, ws.quantity as ws_qty
            FROM product_batches pb
            INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
            WHERE pb.variation_id = %s
            AND pb.remaining_quantity > 0
            AND ws.warehouse_id = %s
            AND ws.quantity > 0
        """
        params = [variation_id, warehouse_id]
    else:
        query = """
            SELECT pb.batch_id, pb.grn_id, pb.remaining_quantity,
                   pb.cost, pb.price, pb.expiration_date,
                   ws.id as warehouse_stock_id, ws.quantity as ws_qty
            FROM product_batches pb
            INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
            WHERE pb.product_id = %s
            AND (pb.variation_id IS NULL OR pb.variation_id = 0)
            AND pb.remaining_quantity > 0
            AND ws.warehouse_id = %s
            AND ws.quantity > 0
        """
        params = [product_id, warehouse_id]

    if price is not None:
        # our_price column was dropped — filter by pb.price
        query += " AND pb.price = %s"
        params.append(float(price))

    # FIFO by batch_id; FOR UPDATE locks the selected rows for this transaction.
    query += " ORDER BY pb.batch_id ASC FOR UPDATE"
    cursor.execute(query, tuple(params))

    batches = cursor.fetchall()

    if not batches:
        price_msg = f" at price Rs. {price}" if price is not None else ""
        raise Exception(
            f"No stock found{price_msg} in warehouse {warehouse_id} for product {product_id}"
            f"{f', variation {variation_id}' if variation_id else ''}"
        )

    # Defensive check: the SQL already filters on pb.price, so this only logs.
    if price is not None:
        for batch in batches:
            batch_price = float(batch.get('price', 0))
            if abs(batch_price - float(price)) > 0.01:
                print(f"⚠️ WARNING: Batch {batch['batch_id']} price mismatch! Expected: {price}, Got: {batch_price}")

    # FIFO deduction across batches
    for batch in batches:
        if remaining_qty <= float_residue:
            break

        # A batch can only give what both product_batches and warehouse_stock hold.
        available_in_batch = min(float(batch['remaining_quantity']), float(batch['ws_qty']))
        deduct_now = min(remaining_qty, available_in_batch)

        cost_value = batch.get('cost', 0)
        if cost_value is None or cost_value == 0:
            print(f"⚠️ WARNING: Batch {batch['batch_id']} has zero/null cost!")

        # Update product_batches
        cursor.execute("""
            UPDATE product_batches
            SET remaining_quantity = remaining_quantity - %s
            WHERE batch_id = %s
        """, (deduct_now, batch['batch_id']))

        if cursor.rowcount != 1:
            raise Exception(f"Failed to update product_batches for batch {batch['batch_id']}")

        # Update warehouse_stock
        cursor.execute("""
            UPDATE warehouse_stock
            SET quantity = quantity - %s
            WHERE id = %s
        """, (deduct_now, batch['warehouse_stock_id']))

        if cursor.rowcount != 1:
            raise Exception(f"Failed to update warehouse_stock for warehouse_stock_id {batch['warehouse_stock_id']}")

        # Insert into sale_product_items for audit trail (used later for restoration)
        cursor.execute("""
            INSERT INTO sale_product_items (
                sale_product_id, batch_id, grn_id, quantity, cost, expiration_date
            ) VALUES (%s, %s, %s, %s, %s, %s)
        """, (
            sale_product_id,
            batch['batch_id'],
            batch['grn_id'],
            deduct_now,
            cost_value,
            batch.get('expiration_date')
        ))

        batches_used.append({
            'batch_id':        batch['batch_id'],
            'grn_id':          batch['grn_id'],
            'quantity':        deduct_now,
            'cost':            cost_value,
            'price':           batch.get('price'),
            'expiration_date': batch.get('expiration_date')
        })
        remaining_qty -= deduct_now

    # Allow small rounding errors
    if remaining_qty > 0.001:
        price_msg = f" at price Rs. {price}" if price is not None else ""
        raise Exception(
            f"Insufficient stock{price_msg} in warehouse {warehouse_id} for product {product_id}"
            f"{f', variation {variation_id}' if variation_id else ''}. "
            f"Short by {remaining_qty:.3f} units."
        )

    return batches_used


def restore_stock_to_batches_warehouse(cursor, sale_product_id, store_id):
    """
    Put the stock consumed by one sale line back into its batches.

    Reads the sale_product_items rows of ``sale_product_id`` (newest item
    first), adds each recorded quantity back to
    product_batches.remaining_quantity and to the matching warehouse_stock row,
    re-inserting that row under ``store_id`` when it no longer exists.
    Finally deletes the sale_product_items rows that were replayed.

    Args:
        cursor: Dictionary cursor inside an open transaction.
        sale_product_id: ID from sale_products whose batch usage is reversed.
        store_id: Store ID written when a warehouse_stock row must be recreated.
    """
    cursor.execute("""
        SELECT spi.sale_product_item_id, spi.batch_id, spi.quantity,
               sp.warehouse_id, sp.product_id, sp.variation_id
        FROM sale_product_items spi
        INNER JOIN sale_products sp ON spi.sale_product_id = sp.sale_product_id
        WHERE spi.sale_product_id = %s
        ORDER BY spi.sale_product_item_id DESC
    """, (sale_product_id,))
    consumed_rows = cursor.fetchall()

    if not consumed_rows:
        print(f"   ℹ️ No sale_product_items found for sale_product_id {sale_product_id}")
        return

    def _give_back(row):
        """Return one sale_product_items quantity to its batch and warehouse row."""
        qty = float(row['quantity'])
        batch_id = row['batch_id']
        warehouse_id = row['warehouse_id']

        cursor.execute("""
            UPDATE product_batches
            SET remaining_quantity = remaining_quantity + %s
            WHERE batch_id = %s
        """, (qty, batch_id))
        if cursor.rowcount != 1:
            print(f"   ⚠️ Warning: Batch {batch_id} not found in product_batches")

        cursor.execute("""
            SELECT id FROM warehouse_stock
            WHERE batch_id = %s AND warehouse_id = %s
        """, (batch_id, warehouse_id))
        existing_ws = cursor.fetchone()

        if not existing_ws:
            # The warehouse_stock row is gone: recreate it under the caller's store_id.
            print(f"   ℹ️ Recreating warehouse_stock for batch {batch_id} in warehouse {warehouse_id}")
            cursor.execute("""
                INSERT INTO warehouse_stock (
                    store_id, product_id, warehouse_id, variation_id, batch_id, quantity
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """, (store_id, row['product_id'], warehouse_id, row['variation_id'], batch_id, qty))
        else:
            cursor.execute("""
                UPDATE warehouse_stock
                SET quantity = quantity + %s
                WHERE batch_id = %s AND warehouse_id = %s
            """, (qty, batch_id, warehouse_id))

        print(f"   📦 Restored {qty} units to batch {batch_id} in warehouse {warehouse_id}")

    for consumed in consumed_rows:
        _give_back(consumed)

    # The audit rows have been replayed, so drop them.
    cursor.execute("""
        DELETE FROM sale_product_items
        WHERE sale_product_id = %s
    """, (sale_product_id,))
    print(f"   🗑️ Deleted {cursor.rowcount} sale_product_items entries")


def get_payment_method_id_and_name(cursor, payment_type_id):
    """
    Resolve a payment method ID to its (id, method_name) pair.

    The value is converted with int() and looked up among active
    payment_methods. If it is not found, or cannot be converted, the active
    method named 'cash' (case-insensitive) is used instead. If that is also
    missing, the hard-coded pair (1, 'Cash') is returned.

    Args:
        cursor: Dictionary cursor on an open DB connection.
        payment_type_id: Payment method ID (int or numeric string).

    Returns:
        tuple: (method_id, method_name).
    """
    def _cash_fallback():
        """Look up the active 'cash' method, defaulting to (1, 'Cash')."""
        cursor.execute("""
            SELECT id, method_name
            FROM payment_methods
            WHERE LOWER(method_name) = 'cash' AND is_active = 1
            LIMIT 1
        """)
        cash_row = cursor.fetchone()
        if not cash_row:
            return 1, 'Cash'
        return cash_row['id'], cash_row['method_name']

    try:
        wanted_id = int(payment_type_id)

        cursor.execute("""
            SELECT id, method_name
            FROM payment_methods
            WHERE id = %s AND is_active = 1
        """, (wanted_id,))
        match = cursor.fetchone()

        if not match:
            print(f"⚠️ WARNING: Payment method ID {wanted_id} not found - defaulting to Cash")
            return _cash_fallback()
        return match['id'], match['method_name']

    except (ValueError, TypeError) as e:
        print(f"⚠️ Error converting payment_type_id to int: {e} - defaulting to Cash")
        return _cash_fallback()


def get_unit_identifier(cursor, unit_value):
    """
    Translate a unit ID into the value stored in the sale_unit column.

    ✅ Returns units.unit_short (e.g. "kg", "pcs", "box") when the unit exists
       and its short name is non-empty.
    ✅ Otherwise returns the ID as a string; a value that is not int()-convertible
       is returned as str(value) without a lookup.

    Args:
        cursor: Dictionary cursor on an open DB connection.
        unit_value: Unit ID (int or string), or None.

    Returns:
        str | None: Unit short name or string fallback; None when unit_value is None.
    """
    if unit_value is None:
        return None

    try:
        numeric_id = int(unit_value)

        cursor.execute("""
            SELECT unit_short
            FROM units
            WHERE id = %s
        """, (numeric_id,))
        row = cursor.fetchone()

        short_name = row['unit_short'] if row else None
        if short_name:
            return short_name

        print(f"⚠️ WARNING: Unit ID {numeric_id} not found or has no short name - using ID as fallback")
        return str(numeric_id)

    except (ValueError, TypeError):
        return str(unit_value)


# ============================================
# MAIN ENDPOINTS
# ============================================

@sale_bp.route('/submit_invoice', methods=['POST'])
@jwt_required()
@role_required('admin', 'cashier')
def submit_invoice():
    """
    Create or update an invoice with warehouse-specific PRICE-AWARE FIFO stock management
    and multiple payment methods support.

    ✅ MULTI-WAREHOUSE: Each product can have its own warehouse_id (falls back to the invoice's)
    ✅ BATCH TRACKING: Uses sale_product_items table for complete audit trail
    ✅ RACE CONDITION PREVENTION: Uses row-level locking (FOR UPDATE)
    ✅ MULTIPLE PAYMENTS: Stores in sale_payments table with proper IDs
    ✅ PRICE-AWARE FIFO: Only deducts from batches matching selected price
    ✅ PROPER RESTORATION: On update, old lines' stock is restored before re-inserting
    ✅ UNIT HANDLING: Converts the sales_unit ID to its unit short name
    ✅ our_price column removed — single pb.price used throughout

    When ``invoice_id`` is present the invoice is updated. Its old sale_products
    and sale_payments rows are replaced, and its stored invoice_code is kept
    unless a new one is sent. Otherwise a new invoice is inserted.

    Returns:
        200 with invoice identifiers on success; 400 when the body is not a JSON
        object, store_id is missing or products is empty; 500 (transaction rolled
        back) for any other error.
    """
    conn   = None
    cursor = None
    try:
        # silent=True: a missing/invalid JSON body yields None -> clean 400 below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        print("📥 Incoming Invoice JSON Data:")
        print(json.dumps(data, indent=4, default=str))

        # Extract invoice data
        invoice_id      = data.get('invoice_id')
        tenderAmount    = float(data.get('tenderAmount', 0))
        invoiceTotal    = float(data.get('invoiceTotal', 0))
        remainingAmount = float(data.get('remainingAmount', 0))
        discount_value  = float(data.get('discountValue', 0))
        customer_id     = data.get('customer_id')
        products        = data.get('products', [])
        note            = data.get('note', '')
        payment_notes   = data.get('payment_notes', '')
        tax             = float(data.get('tax', 0))
        warehouse_id    = data.get('warehouse_id')
        store_id        = data.get('store_id')
        payment_status  = data.get('payment_status', 'paid')
        payment_methods = data.get('payment_methods')  # Can be None

        # Validate required fields
        if not store_id:
            return jsonify({"error": "Store ID is required"}), 400

        if not products:
            return jsonify({"error": "At least one product is required"}), 400

        # Auto-determine status based on payment_status
        status = 'received' if payment_status == 'paid' else 'suspended'

        # Note stored on the invoice row (payment notes appended when present)
        combined_note = f"{note}\nPayment Notes: {payment_notes}" if payment_notes else note

        # Get current user ID from JWT
        current_user = get_jwt_identity()

        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # Start transaction with proper isolation level
        cursor.execute("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
        cursor.execute("START TRANSACTION")

        # Generate invoice code if new invoice. Empty string is normalised to
        # None so an update keeps the stored code instead of blanking it.
        invoice_code = data.get('invoice_code') or None
        if not invoice_id and not invoice_code:
            today_str = datetime.now().strftime("%Y%m%d")
            cursor.execute("SELECT COUNT(*) as count FROM invoice_sale WHERE DATE(created_at) = CURDATE()")
            count = cursor.fetchone()["count"] + 1
            invoice_code = f"INV-{today_str}-{count:04d}"

        # ✅ Determine payment_method for invoice_sale table (stores display string)
        if not payment_methods:
            primary_payment_method = "unpaid"
        elif len(payment_methods) == 1:
            _, method_name = get_payment_method_id_and_name(cursor, payment_methods[0].get('type'))
            primary_payment_method = method_name.lower()
        else:
            primary_payment_method = "multiple"

        # Insert or Update Invoice
        if invoice_id:
            # UPDATE EXISTING INVOICE
            print(f"   🔄 Updating existing invoice {invoice_id}...")

            cursor.execute("""
                SELECT sale_product_id
                FROM sale_products
                WHERE invoice_id = %s
            """, (invoice_id,))
            old_sale_products = cursor.fetchall()

            # Restore stock using sale_product_items with proper store_id
            print(f"   ⏪ Restoring stock from {len(old_sale_products)} products...")
            for old_sp in old_sale_products:
                restore_stock_to_batches_warehouse(cursor, old_sp['sale_product_id'], store_id)

            # Delete old products and payments
            cursor.execute("DELETE FROM sale_products WHERE invoice_id=%s", (invoice_id,))
            print(f"   🗑️ Deleted {cursor.rowcount} old sale_products")

            cursor.execute("DELETE FROM sale_payments WHERE sale_id=%s", (invoice_id,))
            print(f"   🗑️ Deleted {cursor.rowcount} old sale_payments")

            # Update invoice. COALESCE keeps the existing invoice_code when the
            # client did not send one (previously it was overwritten with NULL).
            cursor.execute("""
                UPDATE invoice_sale SET
                    invoice_code=COALESCE(%s, invoice_code), tenderAmount=%s, remainingAmount=%s,
                    invoiceTotal=%s, payment_method=%s, payment_status=%s, customer_id=%s,
                    discount=%s, status=%s, note=%s, tax=%s, warehouse_id=%s, store_id=%s
                WHERE invoice_id=%s
            """, (
                invoice_code, tenderAmount, remainingAmount, invoiceTotal,
                primary_payment_method, payment_status, customer_id, discount_value, status,
                combined_note, tax, warehouse_id, store_id, invoice_id
            ))

            if invoice_code is None:
                # Report the code actually stored on the invoice.
                cursor.execute(
                    "SELECT invoice_code FROM invoice_sale WHERE invoice_id=%s", (invoice_id,)
                )
                stored = cursor.fetchone()
                invoice_code = stored['invoice_code'] if stored else None

            print(f"   ✅ Updated invoice {invoice_code}")

        else:
            # CREATE NEW INVOICE
            print("   ➕ Creating new invoice...")
            cursor.execute("""
                INSERT INTO invoice_sale (
                    invoice_code, tenderAmount, remainingAmount, invoiceTotal,
                    payment_method, payment_status, customer_id, discount, status, note, tax,
                    warehouse_id, store_id, cashier_user_id
                ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """, (
                invoice_code, tenderAmount, remainingAmount, invoiceTotal,
                primary_payment_method, payment_status, customer_id, discount_value, status,
                combined_note, tax, warehouse_id, store_id, current_user
            ))
            invoice_id = cursor.lastrowid
            print(f"   ✅ Created new invoice {invoice_code} (ID: {invoice_id})")

        # ✅ Insert Payment Methods with INTEGER IDs
        if payment_methods and payment_status != 'unpaid':
            print(f"   💳 Processing {len(payment_methods)} payment method(s)...")
            for idx, payment in enumerate(payment_methods, 1):
                try:
                    payment_type_value = payment.get('type')
                    payment_amount     = float(payment.get('amount', 0))

                    if payment_amount > 0:
                        method_id, method_name = get_payment_method_id_and_name(cursor, payment_type_value)
                        print(f"      [{idx}] {method_name} (ID: {method_id}): Rs. {payment_amount:.2f}")

                        cursor.execute("""
                            INSERT INTO sale_payments (
                                sale_id, payment_method, amount, payment_date
                            ) VALUES (%s, %s, %s, NOW())
                        """, (invoice_id, method_id, payment_amount))

                        print(f"         ✅ Payment record created (ID: {cursor.lastrowid})")

                except ValueError as ve:
                    # Deliberately best-effort: a malformed amount skips only this payment.
                    print(f"      ⚠️ Invalid payment method value: {payment.get('type')} - {ve}")
                    continue
                except Exception as pe:
                    print(f"      ❌ Error processing payment {idx}: {pe}")
                    raise

        # ✅ Insert Products — our_price column removed, single price used
        print(f"   📦 Processing {len(products)} products...")
        for idx, product in enumerate(products, 1):
            # product_id may arrive as "<id>-<suffix>"; only the leading id is used.
            product_id   = int(str(product['product_id']).split('-')[0])
            variation_id = product.get('variation_id')
            if variation_id:
                variation_id = int(variation_id)
            quantity = float(product['quantity'])

            # ✅ Single selling price (our_price field removed from frontend payload)
            selling_price  = float(product.get('price', product.get('unit_price', 0)))
            selected_price = selling_price   # used for FIFO price filter

            unit_price  = float(product.get('unit_price', selling_price))
            total_price = float(product.get('total', 0))
            discount    = float(product.get('discount', 0))
            product_tax = float(product.get('tax', 0))

            # Get warehouse_id from PRODUCT, not invoice
            product_warehouse_id = product.get('warehouse_id') or warehouse_id

            if not product_warehouse_id:
                raise Exception(f"Warehouse ID missing for product {product_id} (index {idx})")

            # ✅ Get unit identifier from sales_unit
            sales_unit_value     = product.get('sales_unit')
            sale_unit_identifier = get_unit_identifier(cursor, sales_unit_value)

            print(f"   📦 [{idx}/{len(products)}] Product {product_id}, Qty {quantity}, Warehouse {product_warehouse_id}")
            print(f"      Selling Price: Rs.{selling_price:.2f}")
            print(f"      Sales Unit: {sales_unit_value} -> Storing as: {sale_unit_identifier}")

            # Check stock in PRODUCT'S warehouse AT SELECTED PRICE
            available = get_available_stock_by_warehouse(
                cursor, product_id, variation_id, product_warehouse_id, price=selected_price
            )

            if available < quantity:
                raise Exception(
                    f"Insufficient stock at price Rs. {selected_price} in warehouse {product_warehouse_id} "
                    f"for product {product_id}. Available: {available:.2f}, Requested: {quantity:.2f}"
                )

            # ✅ Insert sale_products — our_price column removed
            cursor.execute("""
                INSERT INTO sale_products (
                    invoice_id, product_id, variation_id,
                    warehouse_id,
                    price, quantity, total,
                    discount_type, product_discount, tax_type, product_tax,
                    discount, tax, unit_price, sale_unit
                ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """, (
                invoice_id, product_id, variation_id,
                product_warehouse_id,
                selling_price, quantity, total_price,
                product.get('discount_type'),
                float(product.get('product_discount', 0)),
                product.get('tax_type'),
                float(product.get('product_tax', 0)),
                discount, product_tax, unit_price, sale_unit_identifier
            ))

            sale_product_id = cursor.lastrowid

            # PRICE-AWARE FIFO Deduction using pb.price
            batches_used = deduct_stock_from_batches_fifo_warehouse(
                cursor, product_id, variation_id, quantity,
                product_warehouse_id, store_id, sale_product_id,
                price=selected_price
            )

            print(f"      ✅ Used {len(batches_used)} batches from WH-{product_warehouse_id}")
            for b in batches_used:
                # Batch cost may be NULL (deduct only warns about it); formatting
                # None with :.2f would raise and roll back the whole sale.
                cost_display = f"{float(b['cost']):.2f}" if b['cost'] is not None else "N/A"
                print(f"         • B{b['batch_id']}: {b['quantity']:.2f} units @ Cost Rs.{cost_display}, Price Rs.{b.get('price', 'N/A')}")

        # Commit transaction
        conn.commit()

        print(f"✅ Invoice {invoice_code} created/updated successfully")
        print(f"   💰 Total: {invoiceTotal}, Paid: {tenderAmount}, Balance: {remainingAmount}")

        return jsonify({
            "message":        "Invoice created/updated successfully",
            "invoice_id":     invoice_id,
            "invoice_code":   invoice_code,
            "warehouse_id":   warehouse_id,
            "store_id":       store_id,
            "payment_status": payment_status,
            "invoice_status": status,
            "payment_method": primary_payment_method
        }), 200

    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                # Do not let a failed rollback mask the original error.
                print(f"⚠️ Rollback failed: {rollback_err}")
        print("❌ Error in /submit_invoice:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Runs on every path (success, validation return, error) so neither the
        # cursor nor the connection leaks.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


@sale_bp.route('/view_invoice_with_payments/<int:invoice_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def view_invoice_with_payments(invoice_id):
    """
    Get detailed invoice information including payment breakdown and batch-level tracking.

    Response sections:
      - invoice header (amounts, status, customer / warehouse / store / cashier)
      - payments read from ``sale_payments`` joined to ``payment_methods``
      - products read from ``sale_products``, each with its batch rows from
        ``sale_product_items`` and a cost / profit computation
      - ``profit_summary`` and ``summary`` aggregates

    ✅ our_price / market_price columns removed — single sp.price used
    ✅ p.sku           → product_sku   (always base product sku)
    ✅ pv.variation_sku → variation_sku (only for variation products, NULL for single)
    ✅ Frontend display rule:
         single product    → product_name (product_sku)
         variation product → product_name (product_sku) - variation_type (variation_sku)
         ex: Junsui Naturals Cool Face Wash (jncfw) - 100ml (jncfw100)

    The cursor and connection are released in ``finally`` so every exit path
    (200, 404 and 500) closes them.

    Returns:
        (Response, int): 200 with ``{"success": True, "invoice": {...}}``,
        404 when no invoice has ``invoice_id``, 500 on unexpected errors.
    """
    conn   = None
    cursor = None

    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # ========================================
        # 1. GET INVOICE HEADER
        # ========================================
        cursor.execute("""
            SELECT
                i.invoice_id,
                i.invoice_code,
                i.tenderAmount,
                i.invoiceTotal,
                i.remainingAmount,
                i.payment_method,
                i.payment_status,
                i.sale_date,
                i.customer_id,
                i.status,
                i.note,
                i.discount,
                i.tax,
                i.warehouse_id,
                i.store_id,
                i.cashier_user_id,
                c.name     AS customer_name,
                c.contact  AS customer_contact,
                c.address  AS customer_address,
                c.email    AS customer_email,
                c.dob      AS customer_dob,
                c.gender   AS customer_gender,
                w.warehouse_name,
                w.city     AS warehouse_city,
                w.country  AS warehouse_country,
                s.store_name,
                s.contact  AS store_contact,
                s.email    AS store_email,
                u.name     AS cashier_name,
                u.email    AS cashier_email
            FROM invoice_sale i
            LEFT JOIN customers  c ON i.customer_id      = c.id
            LEFT JOIN warehouses w ON i.warehouse_id     = w.id
            LEFT JOIN stores     s ON i.store_id         = s.id
            LEFT JOIN users      u ON i.cashier_user_id  = u.id
            WHERE i.invoice_id = %s
        """, (invoice_id,))

        invoice = cursor.fetchone()

        if not invoice:
            # Cursor / connection are closed by the finally clause.
            return jsonify({
                "success": False,
                "error":   "Invoice not found",
                "message": f"No invoice exists with ID {invoice_id}"
            }), 404

        # ========================================
        # 2. GET PAYMENT METHODS BREAKDOWN
        # ========================================
        cursor.execute("""
            SELECT
                sp.payment_method       AS method_id,
                pm.method_name,
                sp.amount,
                sp.payment_date,
                sp.card_number_last4,
                sp.card_type,
                sp.transaction_reference
            FROM sale_payments sp
            INNER JOIN payment_methods pm ON sp.payment_method = pm.id
            WHERE sp.sale_id = %s
            ORDER BY sp.payment_date ASC
        """, (invoice_id,))

        payments = cursor.fetchall()

        # ========================================
        # 3. BUILD INVOICE DETAILS OBJECT
        # ========================================
        invoice_details = {
            "invoice_id":   invoice['invoice_id'],
            "invoice_code": invoice['invoice_code'],
            "sale_date":    invoice['sale_date'].strftime('%Y-%m-%d %H:%M:%S') if invoice['sale_date'] else None,
            "status":       invoice['status'],
            "note":         invoice['note'],

            "invoiceTotal":    float(invoice['invoiceTotal'])    if invoice['invoiceTotal']    else 0.00,
            "tenderAmount":    float(invoice['tenderAmount'])    if invoice['tenderAmount']    else 0.00,
            "remainingAmount": float(invoice['remainingAmount']) if invoice['remainingAmount'] else 0.00,
            "discount":        float(invoice['discount'])        if invoice['discount']        else 0.00,
            "tax":             float(invoice['tax'])             if invoice['tax']             else 0.00,

            "payment_method": invoice['payment_method'],
            "payment_status": invoice.get('payment_status', 'paid'),
            "payments": [
                {
                    "method_id":       p['method_id'],
                    "method_name":     p['method_name'],
                    "amount":          float(p['amount']) if p['amount'] else 0.00,
                    "payment_date":    p['payment_date'].strftime('%Y-%m-%d %H:%M:%S') if p['payment_date'] else None,
                    "card_last4":      p.get('card_number_last4'),
                    "card_type":       p.get('card_type'),
                    "transaction_ref": p.get('transaction_reference')
                }
                for p in payments
            ],
            "payment_count": len(payments),

            "customer": {
                "id":      invoice['customer_id'],
                "name":    invoice['customer_name'],
                "contact": invoice['customer_contact'],
                "email":   invoice['customer_email'],
                "address": invoice['customer_address'],
                "dob":     invoice['customer_dob'].strftime('%Y-%m-%d') if invoice.get('customer_dob') else None,
                "gender":  invoice.get('customer_gender')
            } if invoice['customer_id'] else None,

            "warehouse": {
                "id":      invoice['warehouse_id'],
                "name":    invoice['warehouse_name'],
                "city":    invoice.get('warehouse_city'),
                "country": invoice.get('warehouse_country')
            } if invoice['warehouse_id'] else None,

            "store": {
                "id":      invoice['store_id'],
                "name":    invoice['store_name'],
                "contact": invoice.get('store_contact'),
                "email":   invoice.get('store_email')
            } if invoice['store_id'] else None,

            "cashier": {
                "id":    invoice['cashier_user_id'],
                "name":  invoice['cashier_name'],
                "email": invoice['cashier_email']
            } if invoice.get('cashier_user_id') else None,

            "products": []
        }

        # ========================================
        # 4. GET PRODUCTS
        # p.sku AS product_sku   → always the base product sku
        # pv.variation_sku       → variation sku (NULL for single products)
        # pv.variation_name      → e.g. "Size", "Color"
        # pv.variation_type      → e.g. "100ml", "Red"
        # Both SKUs are sent separately (no COALESCE).
        # ========================================
        cursor.execute("""
            SELECT
                sp.sale_product_id,
                sp.product_id,
                sp.variation_id,
                sp.warehouse_id,
                p.product_name,
                p.sku               AS product_sku,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                p.product_image,
                p.barcode_symbology,
                sp.price            AS selling_price,
                sp.quantity,
                sp.total,
                sp.unit_price,
                sp.sale_unit,
                sp.tax_type,
                sp.product_tax,
                sp.product_discount,
                sp.tax,
                sp.discount_type,
                sp.discount,
                w.warehouse_name,
                c.category_name,
                b.brand_name,
                bu.base_unit,
                u.unit_name,
                u.unit_short
            FROM sale_products sp
            INNER JOIN products p           ON sp.product_id   = p.id
            LEFT JOIN product_variations pv ON sp.variation_id = pv.id
            LEFT JOIN warehouses w          ON sp.warehouse_id = w.id
            LEFT JOIN categories c          ON p.category_id   = c.id
            LEFT JOIN brands b              ON p.brand_id      = b.id
            LEFT JOIN base_units bu         ON p.base_unit_id  = bu.id
            LEFT JOIN units u               ON p.sale_unit_id  = u.id
            WHERE sp.invoice_id = %s
            ORDER BY sp.sale_product_id ASC
        """, (invoice_id,))

        products = cursor.fetchall()

        # ========================================
        # 5. PROCESS EACH PRODUCT
        # ========================================
        total_profit = 0.00
        total_cost   = 0.00

        for product in products:
            # Line-level warehouse wins; fall back to the invoice header warehouse.
            product_warehouse = product.get('warehouse_id') or invoice['warehouse_id']

            available_stock = 0
            if product_warehouse:
                available_stock = get_available_stock_by_warehouse(
                    cursor,
                    product['product_id'],
                    product['variation_id'],
                    product_warehouse
                )

            # ========================================
            # 6. GET BATCH BREAKDOWN (FIFO TRACKING)
            # ========================================
            cursor.execute("""
                SELECT
                    spi.sale_product_item_id,
                    spi.batch_id,
                    spi.grn_id,
                    spi.quantity,
                    spi.cost              AS batch_cost,
                    spi.expiration_date,
                    pb.batch_number,
                    pb.price              AS batch_price,
                    pb.remaining_quantity AS batch_remaining_qty,
                    g.grn_code,
                    g.grn_date,
                    g.supplier_id,
                    s.supplier_name
                FROM sale_product_items spi
                LEFT JOIN product_batches pb ON spi.batch_id  = pb.batch_id
                LEFT JOIN grn g              ON spi.grn_id    = g.grn_id
                LEFT JOIN suppliers s        ON g.supplier_id = s.id
                WHERE spi.sale_product_id = %s
                ORDER BY spi.sale_product_item_id ASC
            """, (product['sale_product_id'],))

            batches = cursor.fetchall()

            # Cost is taken from the per-batch cost recorded at sale time.
            product_total_cost = sum(
                float(b['batch_cost'] or 0) * float(b['quantity'] or 0)
                for b in batches
            )

            selling_price   = float(product.get('selling_price') or product.get('unit_price') or 0)
            quantity        = float(product['quantity'] or 0)
            product_revenue = selling_price * quantity
            product_profit  = product_revenue - product_total_cost

            total_cost   += product_total_cost
            total_profit += product_profit

            # ========================================
            # 7. BUILD PRODUCT OBJECT
            # product_sku    = p.sku  (always base product sku)
            # variation_sku  = pv.variation_sku (NULL if single product)
            # variation_type = pv.variation_type (e.g. "100ml")
            #
            # Frontend display:
            #   single    → product_name (product_sku)
            #   variation → product_name (product_sku) - variation_type (variation_sku)
            #   ex: Junsui Naturals Cool Face Wash (jncfw) - 100ml (jncfw100)
            # ========================================
            product_data = {
                "sale_product_id":   product['sale_product_id'],
                "product_id":        product['product_id'],
                "variation_id":      product['variation_id'],
                "product_name":      product['product_name'],          # base name only
                "product_sku":       product['product_sku'],           # p.sku always
                "variation_name":    product.get('variation_name'),    # e.g. "Size"
                "variation_type":    product.get('variation_type'),    # e.g. "100ml"
                "variation_sku":     product.get('variation_sku'),     # pv.variation_sku
                "product_image":     product.get('product_image'),
                "barcode_symbology": product.get('barcode_symbology'),
                "category":          product.get('category_name'),
                "brand":             product.get('brand_name'),
                "base_unit":         product.get('base_unit'),
                "sale_unit":         product['sale_unit'],
                "unit_name":         product.get('unit_name'),
                "unit_short":        product.get('unit_short'),

                "selling_price": selling_price,

                "quantity":   quantity,
                "subtotal":   float(product['total'])      if product['total']      else 0.00,
                "unit_price": float(product['unit_price']) if product['unit_price'] else 0.00,

                "tax_type":         product['tax_type'],
                "product_tax":      float(product['product_tax'])      if product['product_tax']      else 0.00,
                "tax":              float(product['tax'])               if product['tax']               else 0.00,
                "discount_type":    product['discount_type'],
                "product_discount": float(product['product_discount']) if product['product_discount'] else 0.00,
                "discount":         float(product['discount'])         if product['discount']         else 0.00,

                "total_cost":               round(product_total_cost, 2),
                "profit":                   round(product_profit, 2),
                "profit_margin_percentage": round(
                    (product_profit / product_revenue * 100) if product_revenue > 0 else 0, 2
                ),

                "stock_quantity": available_stock,
                "warehouse_id":   product_warehouse,
                # The column is always present in the row (LEFT JOIN -> None),
                # so a .get() default would never apply; use `or` instead.
                "warehouse_name": product.get('warehouse_name') or 'Unknown',

                "batches_used": [
                    {
                        "sale_product_item_id": b['sale_product_item_id'],
                        "batch_id":             b['batch_id'],
                        "batch_number":         b['batch_number'],
                        "grn_id":               b['grn_id'],
                        "grn_code":             b['grn_code'],
                        "grn_date":             b['grn_date'].strftime('%Y-%m-%d') if b.get('grn_date') else None,
                        "supplier_id":          b.get('supplier_id'),
                        "supplier_name":        b.get('supplier_name'),
                        "quantity":             float(b['quantity'] or 0),
                        "batch_cost":           float(b['batch_cost'])  if b.get('batch_cost')  else 0.00,
                        "batch_price":          float(b['batch_price']) if b.get('batch_price') else 0.00,
                        "expiration_date":      b['expiration_date'].strftime('%Y-%m-%d') if b.get('expiration_date') else None,
                        "batch_remaining_qty":  float(b['batch_remaining_qty']) if b.get('batch_remaining_qty') else 0.00,
                        "batch_revenue": round(selling_price * float(b['quantity'] or 0), 2),
                        "batch_profit":  round(
                            (selling_price * float(b['quantity'] or 0)) -
                            (float(b['batch_cost'] or 0) * float(b['quantity'] or 0)), 2
                        )
                    }
                    for b in batches
                ],
                "batch_count": len(batches)
            }

            invoice_details['products'].append(product_data)

        # ========================================
        # 8. PROFIT SUMMARY
        # ========================================
        invoice_details['profit_summary'] = {
            "total_cost":               round(total_cost, 2),
            "total_revenue":            float(invoice['invoiceTotal']) if invoice['invoiceTotal'] else 0.00,
            "gross_profit":             round(total_profit, 2),
            "profit_margin_percentage": round(
                (total_profit / float(invoice['invoiceTotal']) * 100)
                if invoice['invoiceTotal'] else 0, 2
            ),
            "total_discount_given": float(invoice['discount']) if invoice['discount'] else 0.00,
            "total_tax":            float(invoice['tax'])      if invoice['tax']      else 0.00
        }

        # ========================================
        # 9. SUMMARY COUNTS
        # ========================================
        invoice_details['summary'] = {
            "total_items":        len(products),
            "total_quantity":     sum(float(p['quantity'] or 0) for p in products),
            "unique_products":    len(set(p['product_id'] for p in products)),
            "total_batches_used": sum(len(p['batches_used']) for p in invoice_details['products'])
        }

        return jsonify({
            "success": True,
            "invoice": invoice_details
        }), 200

    except Exception as e:
        print("❌ Error in view_invoice_with_payments:", str(e))
        import traceback
        traceback.print_exc()
        return jsonify({
            "success": False,
            "error":   "Internal server error",
            "message": str(e)
        }), 500

    finally:
        # Single cleanup point for the 200, 404 and 500 paths.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@sale_bp.route('/check_stock_availability', methods=['POST'])
@jwt_required()
@role_required('admin', 'cashier')
def check_stock_availability():
    """
    Check if sufficient stock is available in specific warehouses for given products.
    ✅ Supports checking each product in its own warehouse
    ✅ Price-aware stock checking using pb.price (our_price column removed)

    Request JSON:
        {"products": [{"product_id", "variation_id"?, "quantity",
                       "price"? | "unit_price"?, "warehouse_id", "productName"?}, ...]}

    ``product_id`` may be a string like "12-..."; only the part before the
    first '-' is used. A falsy price means "no price filter".

    Returns:
        (Response, int): 200 with ``{"all_available": bool, "products": [...]}``,
        400 when the body has no products, 500 on unexpected errors.
    """
    conn   = None
    cursor = None

    try:
        # silent=True: a missing / non-JSON body becomes None instead of raising,
        # so it falls through to the 400 below rather than surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        products = data.get('products', [])

        if not products:
            return jsonify({"error": "Products array is required"}), 400

        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        results       = []
        all_available = True

        for product in products:
            product_id   = int(str(product['product_id']).split('-')[0])
            variation_id = product.get('variation_id')
            if variation_id:
                variation_id = int(variation_id)
            requested_qty = float(product['quantity'])

            # Single price field — our_price fallback removed
            selected_price = product.get('price') or product.get('unit_price')

            product_warehouse_id = product.get('warehouse_id')

            if not product_warehouse_id:
                results.append({
                    'product_id':   product_id,
                    'variation_id': variation_id,
                    'product_name': product.get('productName', 'Unknown'),
                    'error':        'Warehouse ID missing'
                })
                all_available = False
                continue

            # Check stock in the product's own warehouse, at its price if given.
            available = get_available_stock_by_warehouse(
                cursor, product_id, variation_id, product_warehouse_id,
                price=selected_price if selected_price else None
            )
            is_available = available >= requested_qty

            if not is_available:
                all_available = False

            result_item = {
                'product_id':     product_id,
                'variation_id':   variation_id,
                'product_name':   product.get('productName', 'Unknown'),
                'warehouse_id':   product_warehouse_id,
                'requested_qty':  requested_qty,
                'available_stock':available,
                'is_available':   is_available,
                'shortage':       max(0, requested_qty - available)
            }

            if selected_price:
                result_item['price'] = float(selected_price)

            results.append(result_item)

        return jsonify({
            'all_available': all_available,
            'products':      results
        }), 200

    except Exception as e:
        print("❌ Error checking stock:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Previously the connection leaked whenever a product entry raised.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

@sale_bp.route('/update_invoices/<int:invoice_id>', methods=['PUT'])
@jwt_required()
@role_required('admin', 'cashier')
def update_invoices(invoice_id):
    """
    Update invoice with multi-warehouse support, payment methods, and batch tracking.
    ✅ Compatible with database schema
    ✅ Uses sale_product_items for FIFO tracking
    ✅ Price-aware FIFO stock management
    ✅ FIXED: Payment methods use INTEGER IDs

    Flow, all inside one transaction:
      1. restore stock for every existing sale_products row of the invoice
         (via ``restore_stock_to_batches_warehouse``);
      2. delete the old ``sale_products`` and ``sale_payments`` rows;
      3. update the ``invoice_sale`` header;
      4. insert the new payments (amount > 0 only, skipped when unpaid);
      5. for each product: verify stock at the selected price, insert a
         ``sale_products`` row, then deduct stock FIFO
         (``deduct_stock_from_batches_fifo_warehouse``).
    Any exception rolls the whole transaction back.

    Returns:
        (Response, int): 200 on success, 400 on an invalid body / missing
        store / no products, 404 if the invoice does not exist, 500 otherwise.
    """
    conn   = None
    cursor = None
    try:
        # silent=True: an undecodable body becomes None and is answered with a
        # 400 instead of being reported as a 500 by the generic handler.
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        print(f"📥 Update Invoice {invoice_id}")
        print(json.dumps(data, indent=4, default=str))

        # Extract data
        invoice_code = data.get('invoice_code')
        tenderAmount = float(data.get('tenderAmount', 0))
        remainingAmount = float(data.get('remainingAmount', 0))
        invoiceTotal = float(data.get('invoiceTotal', 0))
        discount_value = float(data.get('discountValue', 0))
        customer_id = data.get('customer_id')
        products = data.get('products', [])
        note = data.get('note', '')
        payment_notes = data.get('payment_notes', '')
        tax = float(data.get('tax', 0))
        warehouse_id = data.get('warehouse_id')
        store_id = data.get('store_id')
        payment_status = data.get('payment_status', 'paid')
        payment_methods = data.get('payment_methods', [])

        if not store_id:
            return jsonify({"error": "Store ID is required"}), 400

        if not products:
            return jsonify({"error": "At least one product is required"}), 400

        # Auto-determine status: fully paid invoices are "received",
        # everything else is "suspended".
        status = 'received' if payment_status == 'paid' else 'suspended'

        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # Start transaction
        cursor.execute("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
        cursor.execute("START TRANSACTION")

        # Check if invoice exists
        cursor.execute("""
            SELECT i.*, s.store_name 
            FROM invoice_sale i
            LEFT JOIN stores s ON i.store_id = s.id
            WHERE i.invoice_id = %s
        """, (invoice_id,))

        existing_invoice = cursor.fetchone()
        if not existing_invoice:
            # End the transaction we just opened; cleanup happens in finally.
            conn.rollback()
            return jsonify({"error": f"Invoice {invoice_id} not found"}), 404

        # Restore stock using sale_product_items
        cursor.execute("""
            SELECT sale_product_id, product_id
            FROM sale_products
            WHERE invoice_id = %s
        """, (invoice_id,))

        old_sale_products = cursor.fetchall()

        print(f"   ⏪ Restoring stock from {len(old_sale_products)} products...")
        for old_sp in old_sale_products:
            # NOTE(review): uses the store_id from the request, not the invoice's
            # previous store — confirm this is intended if stores can change.
            restore_stock_to_batches_warehouse(cursor, old_sp['sale_product_id'], store_id)

        # Delete old records
        cursor.execute("DELETE FROM sale_products WHERE invoice_id = %s", (invoice_id,))
        deleted_products = cursor.rowcount
        cursor.execute("DELETE FROM sale_payments WHERE sale_id = %s", (invoice_id,))
        deleted_payments = cursor.rowcount

        print(f"   🗑️ Deleted {deleted_products} old products, {deleted_payments} old payments")

        # Determine primary payment method
        if len(payment_methods) == 0:
            primary_payment_method = "unpaid"
        elif len(payment_methods) == 1:
            _, method_name = get_payment_method_id_and_name(cursor, payment_methods[0].get('type'))
            primary_payment_method = method_name.lower()
        else:
            primary_payment_method = "multiple"

        # Update invoice
        cursor.execute("""
            UPDATE invoice_sale SET
                invoice_code = %s, tenderAmount = %s, remainingAmount = %s, invoiceTotal = %s,
                payment_method = %s, payment_status = %s, customer_id = %s, discount = %s,
                status = %s, note = %s, tax = %s, warehouse_id = %s, store_id = %s
            WHERE invoice_id = %s
        """, (
            invoice_code, tenderAmount, remainingAmount, invoiceTotal,
            primary_payment_method, payment_status, customer_id, discount_value, 
            status, f"{note}\nPayment Notes: {payment_notes}" if payment_notes else note,
            tax, warehouse_id, store_id, invoice_id
        ))

        # Insert payment methods with INTEGER IDs (payment_methods.id FK)
        if payment_methods and payment_status != 'unpaid':
            for payment in payment_methods:
                payment_type_value = payment.get('type')
                payment_amount = float(payment.get('amount', 0))

                if payment_amount > 0:
                    method_id, method_name = get_payment_method_id_and_name(cursor, payment_type_value)
                    cursor.execute("""
                        INSERT INTO sale_payments (sale_id, payment_method, amount, payment_date)
                        VALUES (%s, %s, %s, NOW())
                    """, (invoice_id, method_id, payment_amount))
                    print(f"   💳 Added payment: {method_name} Rs.{payment_amount}")

        # Insert new products with sale_product_items tracking
        print(f"   📦 Processing {len(products)} products...")
        for idx, product in enumerate(products, 1):
            # product_id may arrive as "12-..."; only the leading id is used.
            product_id = int(str(product['product_id']).split('-')[0])
            variation_id = product.get('variation_id')
            if variation_id:
                variation_id = int(variation_id)

            quantity = float(product['quantity'])

            if quantity <= 0:
                raise Exception(f"Invalid quantity {quantity} for product {product_id}")

            selected_price = float(product.get('price', 0))

            unit_price = float(product.get('unit_price', product.get('price', 0)))
            total_price = float(product.get('total', 0))
            discount = float(product.get('discount', 0))
            product_tax = float(product.get('tax', 0))

            product_warehouse_id = product.get('warehouse_id') or warehouse_id

            if not product_warehouse_id:
                raise Exception(f"Warehouse ID missing for product {product_id}")

            # Check stock at specific price
            available = get_available_stock_by_warehouse(
                cursor, product_id, variation_id, product_warehouse_id, price=selected_price
            )

            if available < quantity:
                raise Exception(
                    f"Insufficient stock at price Rs. {selected_price} in warehouse {product_warehouse_id} for product {product_id}. "
                    f"Available: {available:.2f}, Requested: {quantity:.2f}"
                )

            # Insert sale_products
            cursor.execute("""
                INSERT INTO sale_products (
                    invoice_id, product_id, variation_id, warehouse_id,
                    price, quantity, total, discount_type, product_discount, 
                    tax_type, product_tax, discount, tax, unit_price, sale_unit
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """, (
                invoice_id, product_id, variation_id, product_warehouse_id,
                selected_price, quantity, total_price,
                product.get('discount_type'), float(product.get('product_discount', 0)),
                product.get('tax_type'), float(product.get('product_tax', 0)),
                discount, product_tax, unit_price, product.get('sales_unit')
            ))

            sale_product_id = cursor.lastrowid

            # Deduct stock and create sale_product_items
            deduct_stock_from_batches_fifo_warehouse(
                cursor, product_id, variation_id, quantity, 
                product_warehouse_id, store_id, sale_product_id,
                price=selected_price
            )

            print(f"      [{idx}/{len(products)}] Product {product_id}: {quantity} units @ Rs.{selected_price} from WH-{product_warehouse_id}")

        # Commit transaction
        conn.commit()

        print(f"✅ Invoice {invoice_code} updated successfully!")

        return jsonify({
            "message": f"Invoice {invoice_code} updated successfully",
            "invoice_id": invoice_id,
            "invoice_code": invoice_code,
            "payment_status": payment_status,
            "invoice_status": status,
            "payment_method": primary_payment_method
        }), 200

    except Exception as e:
        if conn is not None:
            conn.rollback()
        print(f"❌ Error updating invoice: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Close on every path, including the early 404 and a failing rollback.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


@sale_bp.route('/view_invoice/<int:invoice_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def view_invoice(invoice_id):
    """
    Get detailed invoice information with batch breakdown.
    ✅ Shows complete FIFO audit trail via sale_product_items
    ✅ FIXED: Properly handles payment_method as INTEGER FK

    Args:
        invoice_id: ``invoice_sale.invoice_id`` of the invoice to load.

    Returns:
        (Response, int): ``{"invoice": {...}}`` with 200 on success. The payload
        contains the invoice header, customer details (plus ``age`` when a DOB
        is on file), the ``sale_payments`` rows joined to ``payment_methods``,
        and every ``sale_products`` line with its current warehouse stock and
        the ``sale_product_items`` batches it consumed.
        404 if the invoice does not exist, 500 on any other error.

    The cursor and connection are always released in ``finally``.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # Get invoice header
        cursor.execute("""
            SELECT 
                i.invoice_id, i.invoice_code, i.tenderAmount, i.invoiceTotal, i.remainingAmount, 
                i.payment_method, i.payment_status, i.sale_date, i.customer_id, i.status, 
                i.note, i.discount, i.tax, i.warehouse_id, i.store_id,
                c.name AS customer_name, c.contact, c.address, c.dob,
                w.warehouse_name, s.store_name, u.name AS created_by_name
            FROM invoice_sale i
            LEFT JOIN customers c ON i.customer_id = c.id
            LEFT JOIN warehouses w ON i.warehouse_id = w.id
            LEFT JOIN stores s ON i.store_id = s.id
            LEFT JOIN users u ON i.cashier_user_id = u.id
            WHERE i.invoice_id = %s
        """, (invoice_id,))

        invoice = cursor.fetchone()

        if not invoice:
            return jsonify({"error": "Invoice not found"}), 404

        # Payment rows, with the integer FK resolved to a method name
        cursor.execute("""
            SELECT sp.payment_method as method_id, pm.method_name, 
                   sp.amount, sp.payment_date
            FROM sale_payments sp
            INNER JOIN payment_methods pm ON sp.payment_method = pm.id
            WHERE sp.sale_id = %s
            ORDER BY sp.payment_date ASC
        """, (invoice_id,))

        payments = cursor.fetchall()

        # The LEFT JOINs always return these columns, possibly as NULL, so the
        # fallbacks must use `or`. dict.get's default would never apply here.
        invoice_details = {
            "invoice_id": invoice['invoice_id'],
            "invoice_code": invoice['invoice_code'],
            "tenderAmount": float(invoice['tenderAmount']) if invoice['tenderAmount'] else 0,
            "remainingAmount": float(invoice['remainingAmount']) if invoice['remainingAmount'] else 0,
            "invoiceTotal": float(invoice['invoiceTotal']) if invoice['invoiceTotal'] else 0,
            "payment_method": invoice['payment_method'],
            "payment_status": invoice['payment_status'] or 'paid',
            "sale_date": invoice['sale_date'].strftime('%H:%M:%S %d/%m/%Y') if invoice['sale_date'] else '',
            "status": invoice['status'],
            "note": invoice['note'],
            "discount": float(invoice['discount']) if invoice['discount'] else 0,
            "tax": float(invoice['tax']) if invoice['tax'] else 0,
            "warehouse_id": invoice['warehouse_id'],
            "warehouse_name": invoice['warehouse_name'] or 'Unknown',
            "store_id": invoice['store_id'],
            "store_name": invoice['store_name'] or 'Unknown',
            "customer_id": invoice['customer_id'],
            "customer_name": invoice['customer_name'],
            "contact": invoice['contact'],
            "address": invoice['address'],
            "dob": str(invoice['dob']) if invoice['dob'] else '',
            "created_by": invoice['created_by_name'] or 'System',
            "payments": [
                {
                    "method_id": p['method_id'],
                    "method": p['method_name'],
                    "amount": float(p['amount']),
                    "date": p['payment_date'].strftime('%Y-%m-%d %H:%M:%S') if p['payment_date'] else ''
                }
                for p in payments
            ],
            "payment_count": len(payments),
            "products": []
        }

        # Calculate age if DOB exists
        if invoice['dob']:
            try:
                birth_date = datetime.strptime(str(invoice['dob']), "%Y-%m-%d")
                today_date = datetime.today()
                # Subtract one year if this year's birthday hasn't happened yet
                invoice_details["age"] = today_date.year - birth_date.year - (
                    (today_date.month, today_date.day) < (birth_date.month, birth_date.day)
                )
            except (TypeError, ValueError):
                # DOB not in YYYY-MM-DD form: report the age as unknown
                invoice_details["age"] = None

        # Get products with warehouse info
        cursor.execute("""
            SELECT 
                sp.sale_product_id, sp.product_id, sp.variation_id, sp.warehouse_id,
                p.product_name, pv.variation_name, pv.variation_type,
                COALESCE(pv.variation_sku, p.sku) AS sku,
                sp.price, sp.quantity, sp.total, sp.unit_price, sp.sale_unit,
                sp.tax_type, sp.product_tax, sp.product_discount,
                sp.tax, sp.discount_type, sp.discount,
                w.warehouse_name
            FROM sale_products sp
            JOIN products p ON sp.product_id = p.id
            LEFT JOIN product_variations pv ON sp.variation_id = pv.id
            LEFT JOIN warehouses w ON sp.warehouse_id = w.id
            WHERE sp.invoice_id = %s
            ORDER BY sp.sale_product_id ASC
        """, (invoice_id,))

        products = cursor.fetchall()

        for product in products:
            # Lines without their own warehouse fall back to the invoice's
            # warehouse. The reported name follows the same fallback.
            product_warehouse_id = product['warehouse_id'] or invoice['warehouse_id']
            product_warehouse_name = (
                product['warehouse_name'] or invoice['warehouse_name'] or 'Unknown'
            )

            # Get current stock
            available_stock = 0
            if product_warehouse_id:
                available_stock = get_available_stock_by_warehouse(
                    cursor, product['product_id'], product['variation_id'], product_warehouse_id
                )

            # Get batch breakdown from sale_product_items
            cursor.execute("""
                SELECT 
                    spi.batch_id, spi.grn_id, spi.quantity, spi.cost, spi.expiration_date,
                    pb.batch_number, pb.price as batch_price,
                    g.grn_code, g.grn_date,
                    sup.supplier_name
                FROM sale_product_items spi
                LEFT JOIN product_batches pb ON spi.batch_id = pb.batch_id
                LEFT JOIN grn g ON spi.grn_id = g.grn_id
                LEFT JOIN suppliers sup ON g.supplier_id = sup.id
                WHERE spi.sale_product_id = %s
                ORDER BY spi.sale_product_item_id ASC
            """, (product['sale_product_id'],))

            batches = cursor.fetchall()

            product_detail = {
                "product_id": product['product_id'],
                "variation_id": product['variation_id'],
                "product_name": product['product_name'],
                "variation_name": product['variation_name'],
                "variation_type": product['variation_type'],
                "sku": product['sku'],
                "price": float(product['price']) if product['price'] else 0,
                "quantity": float(product['quantity']) if product['quantity'] else 0,
                "subtotal": float(product['total']) if product['total'] else 0,
                "stock_quantity": available_stock,
                "warehouse_id": product_warehouse_id,
                "warehouse_name": product_warehouse_name,
                "sale_unit": product['sale_unit'],
                "unit_price": float(product['unit_price']) if product['unit_price'] else 0,
                "tax_type": product['tax_type'],
                "product_tax": float(product['product_tax']) if product['product_tax'] else 0,
                "product_discount": float(product['product_discount']) if product['product_discount'] else 0,
                "tax": float(product['tax']) if product['tax'] else 0,
                "discount_type": product['discount_type'],
                "discount": float(product['discount']) if product['discount'] else 0,
                "batches_used": [
                    {
                        "batch_id": b['batch_id'],
                        "batch_number": b['batch_number'],
                        # LEFT JOIN: the batch may be gone, giving a NULL price
                        "batch_price": float(b['batch_price']) if b['batch_price'] is not None else 0.0,
                        "grn_id": b['grn_id'],
                        "grn_code": b['grn_code'],
                        "grn_date": b['grn_date'].strftime('%Y-%m-%d') if b.get('grn_date') else None,
                        "supplier_name": b.get('supplier_name'),
                        "quantity": float(b['quantity']),
                        "cost": float(b['cost']) if b.get('cost') else 0,
                        "expiration_date": b['expiration_date'].strftime('%Y-%m-%d') if b.get('expiration_date') else None
                    }
                    for b in batches
                ],
                "batch_count": len(batches)
            }

            invoice_details['products'].append(product_detail)

        return jsonify({"invoice": invoice_details}), 200

    except Exception as e:
        print("❌ Error in view_invoice:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Release DB resources on every path (success, 404, error). A failing
        # close must not mask the response being returned.
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except Exception as close_err:
                    print(f"⚠️ Error closing DB resource: {close_err}")
    
    
# ============================================
# ✅ FIXED: DELETE INVOICE - WITH SALE_PRODUCT_ITEMS
# ============================================
@sale_bp.route('/delete_invoice/<int:invoice_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_invoice(invoice_id):
    """
    Delete invoice and restore stock using sale_product_items.
    ✅ Properly restores stock to product_batches and warehouse_stock

    Runs in a single READ COMMITTED transaction:
      1. lock the invoice_sale row (SELECT ... FOR UPDATE),
      2. call restore_stock_to_batches_warehouse for every sale_products line,
      3. delete sale_payments, sale_products and the invoice_sale row.

    Args:
        invoice_id: ``invoice_sale.invoice_id`` of the invoice to delete.

    Returns:
        (Response, int): 200 with the deleted invoice's code and total,
        404 if the invoice does not exist (the transaction is rolled back),
        or 500 on any other error (the transaction is rolled back).
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # ✅ Start transaction with proper isolation level
        cursor.execute("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
        cursor.execute("START TRANSACTION")

        # Check if invoice exists. FOR UPDATE locks the row so two concurrent
        # deletes of the same invoice cannot both restore its stock. The
        # second request waits here and finds no row once the first commits.
        cursor.execute("""
            SELECT warehouse_id, store_id, invoice_code, invoiceTotal
            FROM invoice_sale
            WHERE invoice_id = %s
            FOR UPDATE
        """, (invoice_id,))
        invoice_info = cursor.fetchone()

        if not invoice_info:
            conn.rollback()  # end the (empty) transaction before returning
            return jsonify({"message": f"Invoice {invoice_id} not found"}), 404

        # ✅ Get sale_product IDs to restore stock
        cursor.execute("""
            SELECT sale_product_id, product_id
            FROM sale_products
            WHERE invoice_id = %s
        """, (invoice_id,))
        sale_products = cursor.fetchall()

        if not sale_products:
            print(f"   ℹ️ No products found for invoice {invoice_id}")

        # ✅ Restore stock using sale_product_items (WITH store_id parameter)
        print(f"   ⏪ Restoring stock from {len(sale_products)} products...")
        for item in sale_products:
            restore_stock_to_batches_warehouse(cursor, item['sale_product_id'], invoice_info['store_id'])
            print(f"      ✅ Restored stock for product {item['product_id']}")

        # Delete records (CASCADE will delete sale_product_items)
        cursor.execute("DELETE FROM sale_payments WHERE sale_id = %s", (invoice_id,))
        deleted_payments = cursor.rowcount

        cursor.execute("DELETE FROM sale_products WHERE invoice_id = %s", (invoice_id,))
        deleted_products = cursor.rowcount

        cursor.execute("DELETE FROM invoice_sale WHERE invoice_id = %s", (invoice_id,))

        print(f"   🗑️ Deleted {deleted_payments} payments, {deleted_products} products")

        # ✅ Commit transaction
        conn.commit()

        invoice_code = invoice_info['invoice_code']
        # invoiceTotal may be NULL. Guard it so a committed delete is never
        # reported as a 500.
        invoice_total = invoice_info['invoiceTotal']

        print(f"✅ Invoice {invoice_code} deleted successfully")
        return jsonify({
            "message": f"Invoice {invoice_code} deleted successfully",
            "invoice_code": invoice_code,
            "amount": float(invoice_total) if invoice_total else 0
        }), 200

    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_err:
                print(f"⚠️ Rollback failed: {rollback_err}")
        print(f"❌ Error deleting invoice: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Release DB resources on every path. A failing close must not mask
        # the response being returned.
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except Exception as close_err:
                    print(f"⚠️ Error closing DB resource: {close_err}")


# ============================================
# ✅ FIXED: VIEW ALL INVOICES - WITH INTEGER PAYMENT_METHOD
# ============================================
@sale_bp.route('/view_all_invoices', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def view_all_invoices():
    """
    Get all invoices with filters - includes payment breakdown.
    ✅ Optimized: Avoids N+1 query problem
    ✅ FIXED: Properly handles INTEGER payment_method FK

    Query params:
        status: exact ``invoice_sale.status`` to match ("all" or absent = any).
        dateFilter: one of today, yesterday, thisWeek, lastWeek, thisMonth,
            lastMonth, customRange, allTime (absent or unknown = no date filter).
        startDate / endDate: ``YYYY-MM-DD``. Used only with ``customRange``,
            and only when both are given.

    Returns:
        (Response, int): ``{"invoices": [...]}`` newest first with 200, each
        entry carrying its payments and product lines. 400 for a malformed
        custom date, 500 on any other error.

    The cursor and connection are always released in ``finally``.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        status_filter = request.args.get('status')
        date_filter = request.args.get('dateFilter')
        start_date = request.args.get('startDate')
        end_date = request.args.get('endDate')

        today = datetime.now().date()

        base_query = """
            SELECT 
                i.invoice_id, i.invoice_code, i.tenderAmount, i.remainingAmount, 
                i.invoiceTotal, i.payment_method, i.payment_status,
                i.customer_id, c.name AS customer_name, i.sale_date, i.status,
                i.warehouse_id, w.warehouse_name, i.store_id, s.store_name,
                u.name AS created_by_name, i.created_at
            FROM invoice_sale i
            LEFT JOIN customers c ON i.customer_id = c.id
            LEFT JOIN warehouses w ON i.warehouse_id = w.id
            LEFT JOIN stores s ON i.store_id = s.id
            LEFT JOIN users u ON i.cashier_user_id = u.id
        """

        conditions = []
        params = []

        if status_filter and status_filter.lower() != "all":
            conditions.append("i.status = %s")
            params.append(status_filter)

        if date_filter and date_filter != "allTime":
            if date_filter == "today":
                conditions.append("DATE(i.sale_date) = %s")
                params.append(today)
            elif date_filter == "yesterday":
                conditions.append("DATE(i.sale_date) = %s")
                params.append(today - timedelta(days=1))
            elif date_filter == "thisWeek":
                # Weeks run Monday..Sunday (date.weekday(): Monday == 0)
                start_of_week = today - timedelta(days=today.weekday())
                end_of_week = start_of_week + timedelta(days=6)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_week, end_of_week])
            elif date_filter == "lastWeek":
                start_of_last_week = today - timedelta(days=today.weekday() + 7)
                end_of_last_week = start_of_last_week + timedelta(days=6)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_last_week, end_of_last_week])
            elif date_filter == "thisMonth":
                start_of_month = today.replace(day=1)
                # Day 28 + 4 always lands in next month. Step back from its 1st.
                end_of_month = (today.replace(day=28) + timedelta(days=4)).replace(day=1) - timedelta(days=1)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_month, end_of_month])
            elif date_filter == "lastMonth":
                first_day_this_month = today.replace(day=1)
                end_of_last_month = first_day_this_month - timedelta(days=1)
                start_of_last_month = end_of_last_month.replace(day=1)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_last_month, end_of_last_month])
            elif date_filter == "customRange" and start_date and end_date:
                try:
                    start = datetime.strptime(start_date, "%Y-%m-%d").date()
                    end = datetime.strptime(end_date, "%Y-%m-%d").date()
                except ValueError:
                    return jsonify({"error": "Invalid custom date format"}), 400
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start, end])

        if conditions:
            base_query += " WHERE " + " AND ".join(conditions)

        base_query += " ORDER BY i.sale_date DESC"
        cursor.execute(base_query, tuple(params))
        invoices = cursor.fetchall()

        if not invoices:
            return jsonify({"invoices": []}), 200

        # ✅ OPTIMIZATION: Fetch all invoice IDs
        invoice_ids = [inv['invoice_id'] for inv in invoices]
        # Only "%s" markers are interpolated. The ids themselves are bound params.
        placeholders = ','.join(['%s'] * len(invoice_ids))

        # ✅ FIXED: Fetch all payments with JOIN to payment_methods table
        cursor.execute(f"""
            SELECT sp.sale_id, sp.payment_method as method_id, pm.method_name, 
                   sp.amount, sp.payment_date
            FROM sale_payments sp
            INNER JOIN payment_methods pm ON sp.payment_method = pm.id
            WHERE sp.sale_id IN ({placeholders})
            ORDER BY sp.sale_id, sp.payment_date ASC
        """, tuple(invoice_ids))

        # Group payments by invoice_id
        payments_by_invoice = {}
        for payment in cursor.fetchall():
            payments_by_invoice.setdefault(payment['sale_id'], []).append(payment)

        # ✅ OPTIMIZATION: Fetch all products in ONE query
        cursor.execute(f"""
            SELECT sp.invoice_id, p.product_name, sp.price, sp.quantity, sp.total
            FROM sale_products sp
            JOIN products p ON sp.product_id = p.id
            WHERE sp.invoice_id IN ({placeholders})
            ORDER BY sp.invoice_id, sp.sale_product_id
        """, tuple(invoice_ids))

        # Group products by invoice_id
        products_by_invoice = {}
        for product in cursor.fetchall():
            products_by_invoice.setdefault(product['invoice_id'], []).append(product)

        # ✅ Build response
        all_invoice_details = []

        for invoice in invoices:
            invoice_id = invoice['invoice_id']
            payments = payments_by_invoice.get(invoice_id, [])
            products = products_by_invoice.get(invoice_id, [])

            invoice_detail = {
                "invoice_id": invoice_id,
                "invoice_code": invoice['invoice_code'],
                "tenderAmount": float(invoice['tenderAmount']) if invoice['tenderAmount'] else 0,
                "remainingAmount": float(invoice['remainingAmount']) if invoice['remainingAmount'] else 0,
                "invoiceTotal": float(invoice['invoiceTotal']) if invoice['invoiceTotal'] else 0,
                "payment_method": invoice['payment_method'],
                # The column is always selected, so a NULL needs `or`, not .get()
                "payment_status": invoice['payment_status'] or 'paid',
                "payment_type": invoice['payment_method'],
                "customer_id": invoice['customer_id'],
                "customer_name": invoice['customer_name'] or 'Walk-in Customer',
                "sale_date": invoice['sale_date'].strftime('%Y-%m-%d %H:%M:%S') if invoice['sale_date'] else '',
                "status": invoice['status'],
                "warehouse_id": invoice['warehouse_id'],
                "warehouse_name": invoice['warehouse_name'] or 'Unknown',
                "store_id": invoice['store_id'],
                "store_name": invoice['store_name'] or 'Unknown',
                "created_by": invoice['created_by_name'] or 'System',
                "created_by_name": invoice['created_by_name'] or 'System',
                "created_at": invoice['created_at'].strftime('%Y-%m-%d %H:%M:%S') if invoice.get('created_at') else '',
                "payments": [
                    {
                        "method_id": p['method_id'],
                        "method": p['method_name'],
                        "amount": float(p['amount']),
                        "date": p['payment_date'].strftime('%Y-%m-%d %H:%M:%S') if p['payment_date'] else ''
                    }
                    for p in payments
                ],
                "payment_count": len(payments),
                "products": [
                    {
                        "product_name": p['product_name'], 
                        "price": float(p['price']) if p['price'] else 0, 
                        "quantity": float(p['quantity']) if p['quantity'] else 0, 
                        "total": float(p['total']) if p['total'] else 0
                    }
                    for p in products
                ]
            }

            all_invoice_details.append(invoice_detail)

        return jsonify({"invoices": all_invoice_details}), 200

    except Exception as e:
        print("❌ Error in view_all_invoices:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Release DB resources on every path (including the early 400/empty
        # returns). A failing close must not mask the response being returned.
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except Exception as close_err:
                    print(f"⚠️ Error closing DB resource: {close_err}")



# ============================================
# ✅ FIXED: REPORT VIEW ALL INVOICES - WITH INTEGER PAYMENT_METHOD
# ============================================
@sale_bp.route('/report_view_all_invoices', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def report_view_all_invoices():
    """
    Get all invoices with filters - includes payment breakdown (report endpoint).
    ✅ Optimized: Avoids N+1 query problem
    ✅ FIXED: Properly handles INTEGER payment_method FK

    Query params:
        status: exact ``invoice_sale.status`` to match ("all" or absent = any).
        dateFilter: one of today, yesterday, thisWeek, lastWeek, thisMonth,
            lastMonth, customRange, allTime (absent or unknown = no date filter).
        startDate / endDate: ``YYYY-MM-DD``. Used only with ``customRange``,
            and only when both are given.

    Returns:
        (Response, int): ``{"invoices": [...]}`` newest first with 200, each
        entry carrying its payments and product lines. 400 for a malformed
        custom date, 500 on any other error.

    The cursor and connection are always released in ``finally``.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        status_filter = request.args.get('status')
        date_filter = request.args.get('dateFilter')
        start_date = request.args.get('startDate')
        end_date = request.args.get('endDate')

        today = datetime.now().date()

        base_query = """
            SELECT 
                i.invoice_id, i.invoice_code, i.tenderAmount, i.remainingAmount, 
                i.invoiceTotal, i.payment_method, i.payment_status,
                i.customer_id, c.name AS customer_name, i.sale_date, i.status,
                i.warehouse_id, w.warehouse_name, i.store_id, s.store_name,
                u.name AS created_by_name, i.created_at
            FROM invoice_sale i
            LEFT JOIN customers c ON i.customer_id = c.id
            LEFT JOIN warehouses w ON i.warehouse_id = w.id
            LEFT JOIN stores s ON i.store_id = s.id
            LEFT JOIN users u ON i.cashier_user_id = u.id
        """

        conditions = []
        params = []

        if status_filter and status_filter.lower() != "all":
            conditions.append("i.status = %s")
            params.append(status_filter)

        if date_filter and date_filter != "allTime":
            if date_filter == "today":
                conditions.append("DATE(i.sale_date) = %s")
                params.append(today)
            elif date_filter == "yesterday":
                conditions.append("DATE(i.sale_date) = %s")
                params.append(today - timedelta(days=1))
            elif date_filter == "thisWeek":
                # Weeks run Monday..Sunday (date.weekday(): Monday == 0)
                start_of_week = today - timedelta(days=today.weekday())
                end_of_week = start_of_week + timedelta(days=6)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_week, end_of_week])
            elif date_filter == "lastWeek":
                start_of_last_week = today - timedelta(days=today.weekday() + 7)
                end_of_last_week = start_of_last_week + timedelta(days=6)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_last_week, end_of_last_week])
            elif date_filter == "thisMonth":
                start_of_month = today.replace(day=1)
                # Day 28 + 4 always lands in next month. Step back from its 1st.
                end_of_month = (today.replace(day=28) + timedelta(days=4)).replace(day=1) - timedelta(days=1)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_month, end_of_month])
            elif date_filter == "lastMonth":
                first_day_this_month = today.replace(day=1)
                end_of_last_month = first_day_this_month - timedelta(days=1)
                start_of_last_month = end_of_last_month.replace(day=1)
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start_of_last_month, end_of_last_month])
            elif date_filter == "customRange" and start_date and end_date:
                try:
                    start = datetime.strptime(start_date, "%Y-%m-%d").date()
                    end = datetime.strptime(end_date, "%Y-%m-%d").date()
                except ValueError:
                    return jsonify({"error": "Invalid custom date format"}), 400
                conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
                params.extend([start, end])

        if conditions:
            base_query += " WHERE " + " AND ".join(conditions)

        base_query += " ORDER BY i.sale_date DESC"
        cursor.execute(base_query, tuple(params))
        invoices = cursor.fetchall()

        if not invoices:
            return jsonify({"invoices": []}), 200

        # ✅ OPTIMIZATION: Fetch all invoice IDs
        invoice_ids = [inv['invoice_id'] for inv in invoices]
        # Only "%s" markers are interpolated. The ids themselves are bound params.
        placeholders = ','.join(['%s'] * len(invoice_ids))

        # ✅ FIXED: Fetch all payments with JOIN to payment_methods table
        cursor.execute(f"""
            SELECT sp.sale_id, sp.payment_method as method_id, pm.method_name, 
                   sp.amount, sp.payment_date
            FROM sale_payments sp
            INNER JOIN payment_methods pm ON sp.payment_method = pm.id
            WHERE sp.sale_id IN ({placeholders})
            ORDER BY sp.sale_id, sp.payment_date ASC
        """, tuple(invoice_ids))

        # Group payments by invoice_id
        payments_by_invoice = {}
        for payment in cursor.fetchall():
            payments_by_invoice.setdefault(payment['sale_id'], []).append(payment)

        # ✅ OPTIMIZATION: Fetch all products in ONE query
        cursor.execute(f"""
            SELECT sp.invoice_id, p.product_name, sp.price, sp.quantity, sp.total
            FROM sale_products sp
            JOIN products p ON sp.product_id = p.id
            WHERE sp.invoice_id IN ({placeholders})
            ORDER BY sp.invoice_id, sp.sale_product_id
        """, tuple(invoice_ids))

        # Group products by invoice_id
        products_by_invoice = {}
        for product in cursor.fetchall():
            products_by_invoice.setdefault(product['invoice_id'], []).append(product)

        # ✅ Build response
        all_invoice_details = []

        for invoice in invoices:
            invoice_id = invoice['invoice_id']
            payments = payments_by_invoice.get(invoice_id, [])
            products = products_by_invoice.get(invoice_id, [])

            invoice_detail = {
                "invoice_id": invoice_id,
                "invoice_code": invoice['invoice_code'],
                "tenderAmount": float(invoice['tenderAmount']) if invoice['tenderAmount'] else 0,
                "remainingAmount": float(invoice['remainingAmount']) if invoice['remainingAmount'] else 0,
                "invoiceTotal": float(invoice['invoiceTotal']) if invoice['invoiceTotal'] else 0,
                "payment_method": invoice['payment_method'],
                # The column is always selected, so a NULL needs `or`, not .get()
                "payment_status": invoice['payment_status'] or 'paid',
                "payment_type": invoice['payment_method'],
                "customer_id": invoice['customer_id'],
                "customer_name": invoice['customer_name'] or 'Walk-in Customer',
                "sale_date": invoice['sale_date'].strftime('%Y-%m-%d %H:%M:%S') if invoice['sale_date'] else '',
                "status": invoice['status'],
                "warehouse_id": invoice['warehouse_id'],
                "warehouse_name": invoice['warehouse_name'] or 'Unknown',
                "store_id": invoice['store_id'],
                "store_name": invoice['store_name'] or 'Unknown',
                "created_by": invoice['created_by_name'] or 'System',
                "created_by_name": invoice['created_by_name'] or 'System',
                "created_at": invoice['created_at'].strftime('%Y-%m-%d %H:%M:%S') if invoice.get('created_at') else '',
                "payments": [
                    {
                        "method_id": p['method_id'],
                        "method": p['method_name'],
                        "amount": float(p['amount']),
                        "date": p['payment_date'].strftime('%Y-%m-%d %H:%M:%S') if p['payment_date'] else ''
                    }
                    for p in payments
                ],
                "payment_count": len(payments),
                "products": [
                    {
                        "product_name": p['product_name'], 
                        "price": float(p['price']) if p['price'] else 0, 
                        "quantity": float(p['quantity']) if p['quantity'] else 0, 
                        "total": float(p['total']) if p['total'] else 0
                    }
                    for p in products
                ]
            }

            all_invoice_details.append(invoice_detail)

        return jsonify({"invoices": all_invoice_details}), 200

    except Exception as e:
        print("❌ Error in report_view_all_invoices:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Release DB resources on every path (including the early 400/empty
        # returns). A failing close must not mask the response being returned.
        for resource in (cursor, conn):
            if resource is not None:
                try:
                    resource.close()
                except Exception as close_err:
                    print(f"⚠️ Error closing DB resource: {close_err}")


# ============================================
# ✅ CORRECTED: DASHBOARD TOP SELLING PRODUCTS
# ============================================
@sale_bp.route('/dashboard_top_selling_products', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def dashboard_top_selling_products():
    """
    Get the top 5 selling products for the dashboard.

    Aggregates ``sale_products`` rows from invoices whose status is
    'received', grouped per product/variation pair, ordered by total
    quantity sold (descending) and limited to 5 rows.

    Each entry contains product and variation identifiers, a display ``sku``
    (the variation SKU for variation lines, otherwise the product SKU),
    ``total_quantity``, ``total_sales``, the highest ``unit_price`` seen,
    ``sale_unit`` (max of ``sp.sale_unit``) and ``sale_unit_name`` (the full
    name of the product's configured sale unit).

    Returns:
        (Response, 200) with ``message`` and ``top_products``, or
        (Response, 500) with ``error`` when anything fails.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT 
                sp.product_id,
                p.product_name,
                p.sku AS product_sku,
                
                sp.variation_id,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                
                -- Total quantity sold
                SUM(sp.quantity) AS total_quantity,
                
                -- Total sales value
                SUM(sp.total) AS total_sales,
                
                -- Unit price (max or avg)
                MAX(sp.unit_price) AS unit_price,
                
                -- Sale unit abbreviation
                MAX(sp.sale_unit) AS sale_unit,
                
                -- ✅ Sale unit full name
                u.unit_name AS sale_unit_name
                
            FROM sale_products sp
            JOIN products p ON sp.product_id = p.id
            LEFT JOIN product_variations pv ON sp.variation_id = pv.id
            JOIN invoice_sale i ON sp.invoice_id = i.invoice_id
            LEFT JOIN units u ON p.sale_unit_id = u.id
            WHERE i.status = 'received'
            GROUP BY 
                sp.product_id,
                p.product_name,
                p.sku,
                sp.variation_id,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                u.unit_name
            ORDER BY total_quantity DESC
            LIMIT 5
        """)

        top_products = []
        for product in cursor.fetchall():
            # Variation lines display the variation SKU; plain products their own.
            display_sku = product['variation_sku'] if product['variation_id'] else product['product_sku']

            top_products.append({
                "product_id": product['product_id'],
                "product_name": product['product_name'],
                "product_sku": product['product_sku'],
                "variation_id": product['variation_id'],
                "variation_name": product['variation_name'],
                "variation_type": product['variation_type'],
                "variation_sku": product['variation_sku'],
                "sku": display_sku,  # ✅ Display SKU (variation or product)
                "total_quantity": float(product['total_quantity']) if product['total_quantity'] else 0,
                "total_sales": float(product['total_sales']) if product['total_sales'] else 0,
                "unit_price": float(product['unit_price']) if product['unit_price'] else 0,
                "sale_unit": product['sale_unit'] or '',
                "sale_unit_name": product['sale_unit_name'] or ''
            })

        return jsonify({
            "message": "Dashboard top selling products fetched successfully",
            "top_products": top_products
        }), 200

    except Exception as e:
        print("❌ Error in dashboard_top_selling_products:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Release DB resources on every path, including when the query raises,
        # so failed requests do not leak pooled connections.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

# ============================================
# ENDPOINT 1: Get All Top Selling Products
# ============================================
@sale_bp.route('/top_selling_products', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def top_selling_products():
    """
    Get ALL selling products ordered by total quantity sold, descending.

    Aggregates ``sale_products`` rows from invoices with status 'received',
    grouped per product/variation pair. Each entry has product and variation
    identifiers, a display ``sku`` (variation SKU for variation lines,
    otherwise the product SKU), ``avg_price`` (AVG of ``sp.unit_price``),
    ``total_quantity``, ``total_sales``, ``sale_unit`` and ``sale_unit_name``
    (full name of the product's configured sale unit from ``units``).

    Returns:
        (Response, 200) with ``message`` and ``top_products``, or
        (Response, 500) with ``error`` when anything fails.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT 
                sp.product_id,
                p.product_name,
                p.sku AS product_sku,
                
                sp.variation_id,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                
                SUM(sp.quantity)   AS total_quantity,
                AVG(sp.unit_price) AS avg_price,
                SUM(sp.total)      AS total_sales,
                MAX(sp.sale_unit)  AS sale_unit,
                u.unit_name        AS sale_unit_name

            FROM sale_products sp
            JOIN products p        ON sp.product_id   = p.id
            LEFT JOIN product_variations pv ON sp.variation_id = pv.id
            JOIN invoice_sale i    ON sp.invoice_id   = i.invoice_id
            LEFT JOIN units u      ON p.sale_unit_id  = u.id
            WHERE i.status = 'received'
            GROUP BY
                sp.product_id,
                p.product_name,
                p.sku,
                sp.variation_id,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                u.unit_name
            ORDER BY total_quantity DESC
        """)
        results = cursor.fetchall()

        top_products = []
        for product in results:
            total_quantity = float(product['total_quantity']) if product['total_quantity'] else 0
            avg_price      = float(product['avg_price'])      if product['avg_price']      else 0
            total_sales    = float(product['total_sales'])    if product['total_sales']    else 0

            # Variation lines display the variation SKU; plain products their own.
            display_sku = product['variation_sku'] if product['variation_id'] else product['product_sku']

            top_products.append({
                "product_id":      product['product_id'],
                "product_name":    product['product_name'],
                "product_sku":     product['product_sku'],
                "variation_id":    product['variation_id'],
                "variation_name":  product['variation_name'],
                "variation_type":  product['variation_type'],
                "variation_sku":   product['variation_sku'],
                "sku":             display_sku,
                "avg_price":       avg_price,
                "total_quantity":  total_quantity,
                "sale_unit":       product['sale_unit']      or '',
                "sale_unit_name":  product['sale_unit_name'] or '',
                "total_sales":     total_sales
            })

        return jsonify({
            "message":      "Top selling products fetched successfully",
            "top_products": top_products
        }), 200

    except Exception as e:
        print("❌ Error in top_selling_products:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Release DB resources on every path, including when the query raises.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


# ============================================
# ENDPOINT 2: Get Price Breakdown for Product
# ============================================
@sale_bp.route('/product_price_breakdown', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def product_price_breakdown():
    """
    Get quantity and sales totals per distinct unit price for one product.

    Query params:
        product_id   (required)
        variation_id (optional) - when missing, empty, "null" or "undefined",
                     only lines WITHOUT a variation (variation_id IS NULL)
                     are counted; otherwise only that variation's lines.

    Only invoices with status 'received' are included. Rows are grouped by
    ``sp.unit_price`` and ordered by quantity sold, highest first.

    Returns:
        (Response, 200) with ``message`` and ``price_breakdown``;
        (Response, 400) when product_id is missing;
        (Response, 500) with ``error`` on any other failure.
    """
    conn = None
    cursor = None
    try:
        product_id   = request.args.get('product_id')
        variation_id = request.args.get('variation_id', None)

        if not product_id:
            return jsonify({"error": "product_id is required"}), 400

        # Front-ends commonly serialize a missing JS value as the literal
        # strings "null" / "undefined"; both mean "no variation".
        has_variation = bool(variation_id) and variation_id not in ('null', 'undefined')

        base_query = """
            SELECT
                sp.unit_price,
                SUM(sp.quantity)  AS quantity_sold,
                SUM(sp.total)     AS total_sales,
                MAX(sp.sale_unit) AS sale_unit,
                u.unit_name       AS sale_unit_name
            FROM sale_products sp
            JOIN invoice_sale i ON sp.invoice_id  = i.invoice_id
            JOIN products p     ON sp.product_id  = p.id
            LEFT JOIN units u   ON p.sale_unit_id = u.id
            WHERE sp.product_id = %s
              AND i.status = 'received'
              {variation_filter}
            GROUP BY sp.unit_price, u.unit_name
            ORDER BY quantity_sold DESC
        """

        # Only fixed SQL fragments are formatted in; user values stay bound params.
        if has_variation:
            query  = base_query.format(variation_filter="AND sp.variation_id = %s")
            params = (product_id, variation_id)
        else:
            query  = base_query.format(variation_filter="AND sp.variation_id IS NULL")
            params = (product_id,)

        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute(query, params)
        results = cursor.fetchall()

        price_breakdown = [
            {
                "unit_price":     float(row['unit_price'])    if row['unit_price']    else 0,
                "quantity_sold":  float(row['quantity_sold']) if row['quantity_sold'] else 0,
                "total_sales":    float(row['total_sales'])   if row['total_sales']   else 0,
                "sale_unit":      row['sale_unit']      or '',
                "sale_unit_name": row['sale_unit_name'] or ''
            }
            for row in results
        ]

        return jsonify({
            "message":         "Price breakdown fetched successfully",
            "price_breakdown": price_breakdown
        }), 200

    except Exception as e:
        print("❌ Error in product_price_breakdown:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Release DB resources on every path, including when the query raises.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

# ============================================
# PROFIT VIEW ALL INVOICES
# ✅ our_price column removed — single sp.price used
# ============================================
@sale_bp.route('/profit_view_all_invoices_sales', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def profit_view_all_invoices_sales():
    """
    List invoices with per-line and per-invoice profit figures.

    Query params:
        status     (optional) invoice status to match; missing or "all"
                   (any case) applies no status filter.
        dateFilter (optional) today | yesterday | thisWeek | lastWeek |
                   thisMonth | lastMonth | customRange | allTime. Missing,
                   "allTime" or unrecognised values apply no date filter.
        startDate / endDate (YYYY-MM-DD) used only with customRange; when
                   either is missing, no date filter is applied.

    Each line uses ``sp.price`` as its single selling price (the our_price
    column was removed). Line cost is the quantity-weighted average of
    ``sale_product_items.cost`` for that line (0 when it has no item rows).
    Line profit = (selling_price - cost) * quantity.

    Returns:
        (Response, 200) with ``{"invoices": [...]}``, newest sale first;
        (Response, 400) for a malformed customRange date;
        (Response, 500) with ``error`` on any other failure.
    """
    conn   = None
    cursor = None
    try:
        status_filter = request.args.get('status')
        date_filter   = request.args.get('dateFilter')
        start_date    = request.args.get('startDate')
        end_date      = request.args.get('endDate')

        # Resolve filters before opening a connection so a 400 response can
        # never leave a DB connection open.
        try:
            date_bounds = _profit_view_date_bounds(
                date_filter, start_date, end_date, datetime.now().date()
            )
        except ValueError:
            return jsonify({"error": "Invalid custom date format"}), 400

        conditions = []
        params     = []

        if status_filter and status_filter.lower() != "all":
            conditions.append("i.status = %s")
            params.append(status_filter)

        if date_bounds is not None:
            conditions.append("DATE(i.sale_date) BETWEEN %s AND %s")
            params.extend(date_bounds)

        # Conditions only reference alias ``i`` (invoice_sale), so the same
        # clause and params apply to both queries below.
        where_clause = (" WHERE " + " AND ".join(conditions)) if conditions else ""
        query_params = tuple(params)

        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT i.invoice_id, i.invoice_code, i.tenderAmount, i.remainingAmount,
                   i.invoiceTotal, i.payment_method, i.customer_id,
                   c.name AS customer_name, i.sale_date, i.status
            FROM invoice_sale i
            LEFT JOIN customers c ON i.customer_id = c.id
        """ + where_clause + " ORDER BY i.sale_date DESC", query_params)
        invoices = cursor.fetchall()

        # One query for the line items of every matching invoice, instead of
        # one query per invoice (N+1). ✅ sp.price is the single selling price.
        cursor.execute("""
            SELECT
                sp.invoice_id,
                p.product_name,
                sp.price                AS selling_price,
                sp.quantity,
                sp.total,
                COALESCE(
                    (SELECT SUM(spi.quantity * spi.cost) / SUM(spi.quantity)
                     FROM sale_product_items spi
                     WHERE spi.sale_product_id = sp.sale_product_id),
                    0
                ) AS weighted_avg_cost
            FROM sale_products sp
            JOIN invoice_sale i ON sp.invoice_id = i.invoice_id
            JOIN products p     ON sp.product_id = p.id
        """ + where_clause + " ORDER BY sp.invoice_id, sp.sale_product_id", query_params)

        products_by_invoice = {}
        for row in cursor.fetchall():
            products_by_invoice.setdefault(row['invoice_id'], []).append(row)

        all_invoice_details = []
        for invoice in invoices:
            products_list   = []
            products_total  = 0
            profit_subtotal = 0

            for p in products_by_invoice.get(invoice['invoice_id'], []):
                selling_price = float(p['selling_price'])     if p['selling_price']     else 0.0
                cost          = float(p['weighted_avg_cost']) if p['weighted_avg_cost'] else 0.0
                quantity      = float(p['quantity'])          if p['quantity']          else 0.0
                total         = float(p['total'])             if p['total']             else 0.0

                # Profit = (selling_price - cost) × quantity
                profit = (selling_price - cost) * quantity
                products_total  += total
                profit_subtotal += profit

                products_list.append({
                    "product_name":  p['product_name'],
                    "selling_price": selling_price,   # ✅ single selling price
                    "price":         selling_price,   # for compatibility
                    "cost":          cost,
                    "quantity":      quantity,
                    "total":         total,
                    "profit":        profit
                })

            all_invoice_details.append({
                "invoice_id":      invoice['invoice_id'],
                "invoice_code":    invoice['invoice_code'],
                "tenderAmount":    float(invoice['tenderAmount'])    if invoice['tenderAmount']    else 0.0,
                "remainingAmount": float(invoice['remainingAmount']) if invoice['remainingAmount'] else 0.0,
                "invoiceTotal":    float(invoice['invoiceTotal'])    if invoice['invoiceTotal']    else 0.0,
                "payment_method":  invoice['payment_method'],
                "customer_id":     invoice['customer_id'],
                "customer_name":   invoice['customer_name'],
                "sale_date":       invoice['sale_date'].strftime('%Y-%m-%d %H:%M:%S') if invoice['sale_date'] else '',
                "status":          invoice['status'],
                "products":        products_list,
                "products_total":  products_total,
                "profit_subtotal": profit_subtotal
            })

        return jsonify({"invoices": all_invoice_details}), 200

    except Exception as e:
        print("❌ Error in profit_view_all_invoices_sales:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Release DB resources on every path, including errors.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


def _profit_view_date_bounds(date_filter, start_date, end_date, today):
    """
    Translate a ``dateFilter`` query value into an inclusive (start, end) date pair.

    Args:
        date_filter: value of the ``dateFilter`` query param (may be None).
        start_date / end_date: ``YYYY-MM-DD`` strings, used only for "customRange".
        today: the reference ``date`` for relative filters.

    Returns:
        A ``(start, end)`` tuple of ``date`` objects, or None when no date
        restriction applies: missing filter, "allTime", unknown values, or
        "customRange" without both bounds.

    Raises:
        ValueError: a customRange bound is not in ``YYYY-MM-DD`` format.
    """
    if not date_filter or date_filter == "allTime":
        return None

    if date_filter == "today":
        return today, today

    if date_filter == "yesterday":
        yesterday = today - timedelta(days=1)
        return yesterday, yesterday

    if date_filter == "thisWeek":
        start_of_week = today - timedelta(days=today.weekday())  # Monday
        return start_of_week, start_of_week + timedelta(days=6)

    if date_filter == "lastWeek":
        start_of_last_week = today - timedelta(days=today.weekday() + 7)
        return start_of_last_week, start_of_last_week + timedelta(days=6)

    if date_filter == "thisMonth":
        start_of_month = today.replace(day=1)
        # Day 28 + 4 days always lands in the next month; step back from its 1st.
        end_of_month = (today.replace(day=28) + timedelta(days=4)).replace(day=1) - timedelta(days=1)
        return start_of_month, end_of_month

    if date_filter == "lastMonth":
        end_of_last_month = today.replace(day=1) - timedelta(days=1)
        return end_of_last_month.replace(day=1), end_of_last_month

    if date_filter == "customRange" and start_date and end_date:
        start = datetime.strptime(start_date, "%Y-%m-%d").date()
        end   = datetime.strptime(end_date,   "%Y-%m-%d").date()
        return start, end

    return None


# ============================================
# DASHBOARD VIEW INVOICES
# ✅ our_price column removed — single sp.price used
# ============================================
@sale_bp.route('/dashboard_view_all_invoices', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def dashboard_view_all_invoices():
    """
    Get the 5 most recent invoices (by sale_date) for the dashboard.

    Query params:
        status (optional) invoice status to match; missing or "all" (any
               case) applies no status filter, matching the behaviour of
               profit_view_all_invoices_sales.

    Each invoice includes its line items, with ``sp.price`` returned as the
    single selling price (also mirrored under ``price`` for compatibility).

    Returns:
        (Response, 200) with ``{"invoices": [...]}``, or
        (Response, 500) with ``error`` on failure.
    """
    conn   = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        status_filter = request.args.get('status')

        base_query = """
            SELECT i.invoice_id, i.invoice_code, i.tenderAmount, i.remainingAmount,
                   i.invoiceTotal, i.payment_method, i.customer_id,
                   c.name AS customer_name, i.sale_date, i.status
            FROM invoice_sale i
            LEFT JOIN customers c ON i.customer_id = c.id
        """

        if status_filter and status_filter.lower() != "all":
            base_query += " WHERE i.status = %s ORDER BY i.sale_date DESC LIMIT 5"
            cursor.execute(base_query, (status_filter,))
        else:
            base_query += " ORDER BY i.sale_date DESC LIMIT 5"
            cursor.execute(base_query)

        invoices            = cursor.fetchall()
        all_invoice_details = []

        for invoice in invoices:
            # ✅ Single selling price — sp.our_price removed, sp.price used
            cursor.execute("""
                SELECT
                    p.product_name,
                    sp.price    AS selling_price,
                    sp.quantity,
                    sp.total
                FROM sale_products sp
                JOIN products p ON sp.product_id = p.id
                WHERE sp.invoice_id = %s
            """, (invoice['invoice_id'],))
            products = cursor.fetchall()

            all_invoice_details.append({
                "invoice_id":      invoice['invoice_id'],
                "invoice_code":    invoice['invoice_code'],
                "tenderAmount":    float(invoice['tenderAmount'])    if invoice['tenderAmount']    else 0.0,
                "remainingAmount": float(invoice['remainingAmount']) if invoice['remainingAmount'] else 0.0,
                "invoiceTotal":    float(invoice['invoiceTotal'])    if invoice['invoiceTotal']    else 0.0,
                "payment_method":  invoice['payment_method'],
                "customer_id":     invoice['customer_id'],
                "customer_name":   invoice['customer_name'],
                "sale_date":       invoice['sale_date'].strftime('%Y-%m-%d %H:%M:%S') if invoice['sale_date'] else '',
                "status":          invoice['status'],
                "products": [
                    {
                        "product_name":  p['product_name'],
                        "selling_price": float(p['selling_price']) if p['selling_price'] else 0.0,
                        "price":         float(p['selling_price']) if p['selling_price'] else 0.0,  # for compatibility
                        "quantity":      float(p['quantity'])      if p['quantity']      else 0.0,
                        "total":         float(p['total'])         if p['total']         else 0.0
                    }
                    for p in products
                ]
            })

        return jsonify({"invoices": all_invoice_details}), 200

    except Exception as e:
        print("❌ Error in dashboard_view_all_invoices:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Release DB resources on every path, including when a query raises.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

@sale_bp.route('/get_today_summary', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier', 'manager')
def get_today_summary():
    """
    Summarize today's sales, payments, returns and expenses.

    "Today" is the server-local calendar day, queried as the half-open
    interval [00:00 today, 00:00 tomorrow) on each table's ``created_at``.

    Payment amounts are capped per invoice. Payments are applied in
    ``sale_payments.id`` order, and any amount that pushes the invoice's
    running total above ``invoiceTotal`` (e.g. tendered cash returned as
    change) is reduced, so an invoice never contributes more than its total.
    The running total always spans ALL of an invoice's payments. Cash, card
    and overall totals are then split from that same capped set, so they
    agree with the ``payment_methods`` breakdown.

    Returns:
        (Response, 200) with the summary JSON, or
        (Response, 500) with ``error`` on failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        today = date.today()
        start_time = datetime.combine(today, datetime.min.time())
        # Exclusive upper bound: avoids relying on how 23:59:59.999999 compares
        # against second-precision DATETIME values.
        end_time = start_time + timedelta(days=1)

        # ============================================
        # 1. TOTAL SALES BY STATUS
        # ============================================
        cursor.execute("""
            SELECT 
                IFNULL(SUM(CASE WHEN status = 'received' THEN invoiceTotal ELSE 0 END), 0) as received_sales,
                IFNULL(SUM(CASE WHEN status = 'suspended' THEN invoiceTotal ELSE 0 END), 0) as suspended_sales,
                IFNULL(SUM(invoiceTotal), 0) as total_sales,
                COUNT(CASE WHEN status = 'received' THEN 1 END) as received_count,
                COUNT(CASE WHEN status = 'suspended' THEN 1 END) as suspended_count
            FROM invoice_sale
            WHERE created_at >= %s AND created_at < %s
        """, (start_time, end_time))
        sales_data = cursor.fetchone()

        # ============================================
        # 2. PAYMENT METHODS BREAKDOWN (capped per invoice)
        # ============================================
        cursor.execute("""
            WITH invoice_payments AS (
                SELECT 
                    sp.sale_id,
                    inv.invoiceTotal,
                    pm.id as payment_method_id,
                    pm.method_name,
                    sp.amount,
                    SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total,
                    sp.id as payment_id
                FROM sale_payments sp
                INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
                INNER JOIN payment_methods pm ON sp.payment_method = pm.id
                WHERE inv.created_at >= %s AND inv.created_at < %s
            )
            SELECT 
                payment_method_id,
                method_name,
                COUNT(DISTINCT payment_id) as transaction_count,
                SUM(
                    CASE 
                        -- If running total exceeds invoice, cap this payment
                        WHEN running_total > invoiceTotal THEN 
                            GREATEST(0, invoiceTotal - (running_total - amount))
                        ELSE amount
                    END
                ) as total_amount
            FROM invoice_payments
            GROUP BY payment_method_id, method_name
            ORDER BY total_amount DESC
        """, (start_time, end_time))
        payment_methods = cursor.fetchall()

        # ============================================
        # 3. CASH / CARD / TOTAL PAYMENTS COLLECTED
        # The method filter is applied only AFTER the running total is
        # computed: window functions run after WHERE, so filtering by method
        # inside the CTE would hide other methods' payments from the running
        # total and the over-payment cap would never trigger.
        # ============================================
        cursor.execute("""
            WITH invoice_payments AS (
                SELECT 
                    inv.invoiceTotal,
                    LOWER(pm.method_name) AS method_key,
                    sp.amount,
                    SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total
                FROM sale_payments sp
                INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
                LEFT JOIN payment_methods pm ON sp.payment_method = pm.id
                WHERE inv.created_at >= %s AND inv.created_at < %s
            ),
            applied_payments AS (
                SELECT
                    method_key,
                    CASE 
                        -- If running total exceeds invoice, cap this payment
                        WHEN running_total > invoiceTotal THEN 
                            GREATEST(0, invoiceTotal - (running_total - amount))
                        ELSE amount
                    END AS applied_amount
                FROM invoice_payments
            )
            SELECT
                IFNULL(SUM(CASE WHEN method_key = 'cash'
                                THEN applied_amount ELSE 0 END), 0) as total_cash,
                IFNULL(SUM(CASE WHEN method_key IN ('card', 'credit card', 'debit card', 'credit/debit card')
                                THEN applied_amount ELSE 0 END), 0) as total_card,
                IFNULL(SUM(applied_amount), 0) as total_payments
            FROM applied_payments
        """, (start_time, end_time))
        collected = cursor.fetchone()
        if collected:
            cash_payment = float(collected['total_cash'])
            card_payment = float(collected['total_card'])
            total_payments_collected = float(collected['total_payments'])
        else:
            cash_payment = card_payment = total_payments_collected = 0.0

        # ============================================
        # 4. SALE RETURNS
        # ============================================
        cursor.execute("""
            SELECT 
                IFNULL(SUM(total_return_amount), 0) as total_return,
                COUNT(*) as return_count
            FROM sale_return
            WHERE created_at >= %s AND created_at < %s
        """, (start_time, end_time))
        return_data = cursor.fetchone()

        # ============================================
        # 5. EXPENSES
        # ============================================
        cursor.execute("""
            SELECT 
                IFNULL(SUM(amount), 0) as total_expense,
                COUNT(*) as expense_count
            FROM expenses
            WHERE created_at >= %s AND created_at < %s
        """, (start_time, end_time))
        expense_data = cursor.fetchone()

        # ============================================
        # 6. PAYMENT STATUS BREAKDOWN
        # ============================================
        cursor.execute("""
            SELECT 
                payment_status,
                COUNT(*) as count,
                IFNULL(SUM(invoiceTotal), 0) as total_amount,
                IFNULL(SUM(LEAST(tenderAmount, invoiceTotal)), 0) as paid_amount,
                IFNULL(SUM(remainingAmount), 0) as due_amount
            FROM invoice_sale
            WHERE created_at >= %s AND created_at < %s
            GROUP BY payment_status
        """, (start_time, end_time))
        payment_status_breakdown = cursor.fetchall()

        # ============================================
        # FINAL CALCULATIONS
        # ============================================
        total_sale = float(sales_data['total_sales'])
        received_sales = float(sales_data['received_sales'])
        suspended_sales = float(sales_data['suspended_sales'])
        sale_return = float(return_data['total_return'])
        total_expense = float(expense_data['total_expense'])

        # Net cash = Cash collected - Returns - Expenses
        net_cash = cash_payment - sale_return - total_expense

        # Total cash in hand = Cash collected - Returns (before expenses)
        total_cash_in_hand = cash_payment - sale_return

        # Emit amounts as JSON numbers, like every other amount in this
        # response; raw DB Decimals would otherwise serialize differently.
        payment_methods = [
            {**row, "total_amount": float(row['total_amount'] or 0)}
            for row in payment_methods
        ]
        payment_status_breakdown = [
            {
                **row,
                "total_amount": float(row['total_amount'] or 0),
                "paid_amount": float(row['paid_amount'] or 0),
                "due_amount": float(row['due_amount'] or 0),
            }
            for row in payment_status_breakdown
        ]

        return jsonify({
            "date": today.strftime("%Y-%m-%d"),
            "total_sale": total_sale,
            "received_sales": received_sales,
            "suspended_sales": suspended_sales,
            "received_count": sales_data['received_count'],
            "suspended_count": sales_data['suspended_count'],
            "payment_methods": payment_methods,
            "payment_status": payment_status_breakdown,
            "cash_payment": cash_payment,
            "card_payment": card_payment,
            "total_payments_collected": total_payments_collected,
            "sale_return": sale_return,
            "return_count": return_data['return_count'],
            "total_expense": total_expense,
            "expense_count": expense_data['expense_count'],
            "net_cash": net_cash,
            "total_cash_in_hand": total_cash_in_hand
        }), 200

    except Exception as e:
        print(f"❌ Error in today summary: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Release DB resources on every path, including when a query raises.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
    
    
# ============================================
# ✅ CORRECTED: GET REGISTER STATUS
# ============================================
@sale_bp.route('/get_register_status', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier', 'manager')
def get_register_status():
    """
    Return the live cash position of a user's currently open register.

    Query parameters:
        user_id: ID of the cashier whose open register is inspected (required).

    Cash in hand is computed as::

        opening_cash + cash_sales + cash_added - cash_removed - expenses - returns

    * ``cash_sales``: cash payments on this cashier's invoices since login.
      Any amount by which an invoice's cumulative payments exceed its
      ``invoiceTotal`` is excluded (presumably returned as change).
    * ``expenses``: expenses created by this user in the register's store
      since login (filtered on ``created_at``).
    * ``returns``: all sale returns recorded for the store since login.

    Returns:
        200 with ``{'success': True, 'data': {...}}`` on success,
        400 if ``user_id`` is missing,
        404 if the user has no open register,
        500 with the error message on any unexpected failure.
    """
    conn = None
    cursor = None
    try:
        user_id = request.args.get('user_id')
        
        if not user_id:
            return jsonify({
                'success': False, 
                'error': 'Missing user_id parameter'
            }), 400

        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)
        
        # Most recent open register for this user
        cursor.execute("""
            SELECT 
                id,
                user_id,
                store_id,
                login_time,
                logout_time,
                status,
                COALESCE(opening_balance, 0) as opening_cash,
                COALESCE(total_cash, 0) as total_cash,
                COALESCE(total_card_sales, 0) as total_card_sales,
                COALESCE(total_sale_amount, 0) as total_sale_amount,
                COALESCE(sale_return_amount, 0) as sale_return_amount,
                COALESCE(total_expense, 0) as total_expense,
                COALESCE(expected_cash, 0) as expected_cash,
                COALESCE(actual_cash, 0) as actual_cash,
                COALESCE(cash_difference, 0) as cash_difference,
                COALESCE(closing_balance, 0) as closing_balance,
                notes
            FROM close_register_logs
            WHERE user_id = %s 
              AND status = 'open'
            ORDER BY id DESC 
            LIMIT 1
        """, (user_id,))
        
        register = cursor.fetchone()

        if not register:
            # Cursor/connection are released in the finally clause.
            return jsonify({
                'success': False, 
                'message': 'No open register found for this user'
            }), 404

        register_id = register['id']
        login_time = register['login_time']
        opening_cash = float(register['opening_cash'] or 0)
        store_id = register['store_id']
        end_time = datetime.now()

        # ============================================
        # GET CASH TRANSACTIONS (manual add/remove on this register)
        # ============================================
        cursor.execute("""
            SELECT 
                COALESCE(SUM(CASE WHEN transaction_type = 'add' THEN amount ELSE 0 END), 0) as cash_added,
                COALESCE(SUM(CASE WHEN transaction_type = 'remove' THEN amount ELSE 0 END), 0) as cash_removed
            FROM cash_transactions
            WHERE register_log_id = %s
        """, (register_id,))
        
        cash_trans = cursor.fetchone()
        cash_added = float(cash_trans['cash_added']) if cash_trans else 0
        cash_removed = float(cash_trans['cash_removed']) if cash_trans else 0

        # ============================================
        # GET CASH SALES - Check payment_method column type
        # ============================================
        cursor.execute("""
            SELECT DATA_TYPE 
            FROM INFORMATION_SCHEMA.COLUMNS 
            WHERE TABLE_NAME = 'sale_payments' 
            AND COLUMN_NAME = 'payment_method'
            AND TABLE_SCHEMA = DATABASE()
        """)
        
        column_info = cursor.fetchone()
        is_int_payment_method = column_info and column_info['DATA_TYPE'] in ('int', 'bigint', 'smallint')

        # NOTE: the running total must be computed over ALL payments of an
        # invoice (cash, card, ...) so that the overpayment cap sees earlier
        # non-cash payments. MySQL evaluates WHERE before window functions,
        # so the cash-only filter is applied in the OUTER query, after the
        # window has been computed.
        if is_int_payment_method:
            # INT version - FK to payment_methods table
            cursor.execute("""
                SELECT COALESCE(SUM(
                    CASE 
                        WHEN running_total > invoiceTotal THEN 
                            GREATEST(0, invoiceTotal - (running_total - amount))
                        ELSE amount
                    END
                ), 0) as cash_sales
                FROM (
                    SELECT 
                        sp.sale_id,
                        inv.invoiceTotal,
                        sp.amount,
                        pm.method_name AS method_name,
                        SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total
                    FROM sale_payments sp
                    INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
                    LEFT JOIN payment_methods pm ON sp.payment_method = pm.id
                    WHERE inv.cashier_user_id = %s
                    AND inv.sale_date BETWEEN %s AND %s
                ) sub
                WHERE LOWER(sub.method_name) = 'cash'
            """, (user_id, login_time, end_time))
        else:
            # VARCHAR/ENUM version - direct string comparison
            cursor.execute("""
                SELECT COALESCE(SUM(
                    CASE 
                        WHEN running_total > invoiceTotal THEN 
                            GREATEST(0, invoiceTotal - (running_total - amount))
                        ELSE amount
                    END
                ), 0) as cash_sales
                FROM (
                    SELECT 
                        sp.sale_id,
                        inv.invoiceTotal,
                        sp.amount,
                        sp.payment_method AS method_name,
                        SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total
                    FROM sale_payments sp
                    INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
                    WHERE inv.cashier_user_id = %s
                    AND inv.sale_date BETWEEN %s AND %s
                ) sub
                WHERE LOWER(sub.method_name) = 'cash'
            """, (user_id, login_time, end_time))
        
        cash_sales_result = cursor.fetchone()
        cash_sales = float(cash_sales_result['cash_sales']) if cash_sales_result else 0

        # ============================================
        # GET EXPENSES - filtered on created_at (datetime) so that
        # same-day expenses after login are included
        # ============================================
        cursor.execute("""
            SELECT COALESCE(SUM(amount), 0) as total_expense
            FROM expenses
            WHERE created_by = %s
            AND created_at BETWEEN %s AND %s
            AND store_id = %s
        """, (user_id, login_time, end_time, store_id))
        
        expense_result = cursor.fetchone()
        total_expense = float(expense_result['total_expense']) if expense_result else 0

        # ============================================
        # GET SALE RETURNS (store-wide since login)
        # ============================================
        cursor.execute("""
            SELECT COALESCE(SUM(total_return_amount), 0) as total_return
            FROM sale_return
            WHERE return_date BETWEEN %s AND %s
            AND store_id = %s
        """, (login_time, end_time, store_id))
        
        return_result = cursor.fetchone()
        total_return = float(return_result['total_return']) if return_result else 0

        # ============================================
        # CALCULATE CURRENT CASH IN HAND
        # ============================================
        current_cash_in_hand = (
            opening_cash + 
            cash_sales + 
            cash_added - 
            cash_removed - 
            total_expense - 
            total_return
        )

        expected_cash = current_cash_in_hand

        return jsonify({
            'success': True,
            'data': {
                'register_id': register_id,
                'user_id': register['user_id'],
                'store_id': store_id,
                'login_time': login_time.strftime('%Y-%m-%d %H:%M:%S') if login_time else None,
                'status': register['status'],
                
                # Cash breakdown
                'opening_cash': opening_cash,
                'cash_sales': cash_sales,
                'cash_added': cash_added,
                'cash_removed': cash_removed,
                'expenses': total_expense,
                'returns': total_return,
                
                # Current totals
                'cash_in_hand': current_cash_in_hand,
                'expected_cash': expected_cash,
                
                # Additional info from register
                'total_sale_amount': float(register.get('total_sale_amount') or 0),
                'total_card_sales': float(register.get('total_card_sales') or 0),
                'actual_cash': float(register.get('actual_cash') or 0),
                'cash_difference': float(register.get('cash_difference') or 0),
                'notes': register.get('notes')
            }
        }), 200

    except Exception as e:
        print(f"❌ Error in get_register_status: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False, 
            'error': str(e)
        }), 500
    finally:
        # Release DB resources on every path, including failed queries.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

# ============================================
# FIXED GET CASH SUMMARY - CORRECT EXPECTED CASH
# ============================================

@sale_bp.route('/get_cash_summary', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier', 'manager')
def get_cash_summary():
    """
    Summarise sales, payments and expected drawer cash for a user's open register.

    Query parameters:
        user_id: ID of the cashier whose open register is summarised (required).

    Expected cash is::

        opening_cash + cash payments received + manual adds
            - manual removes - expenses - returns

    Suspended (unpaid / partially paid) sale balances are reported separately
    and are NOT part of expected cash. Expenses are those created by this user
    in the register's store since login, filtered on ``created_at`` (the same
    filter get_register_status uses). Returns are all sale returns recorded for
    the store since login.

    Returns:
        200 with the summary on success, 400 if ``user_id`` is missing,
        404 if the user has no open register, 500 with the error message on
        any unexpected failure.
    """
    conn = None
    cursor = None
    try:
        user_id = request.args.get('user_id')
        
        if not user_id:
            return jsonify({
                'success': False,
                'error': 'Missing user_id parameter'
            }), 400

        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # ============================================
        # STEP 1: GET OPEN REGISTER
        # ============================================
        cursor.execute("""
            SELECT 
                id,
                user_id,
                store_id,
                login_time,
                logout_time,
                status,
                COALESCE(opening_balance, 0) as opening_cash,
                COALESCE(total_cash, 0) as total_cash,
                COALESCE(total_card_sales, 0) as total_card_sales,
                COALESCE(total_sale_amount, 0) as total_sale_amount,
                COALESCE(sale_return_amount, 0) as sale_return_amount,
                COALESCE(total_expense, 0) as total_expense,
                COALESCE(expected_cash, 0) as expected_cash,
                COALESCE(actual_cash, 0) as counted_cash,
                COALESCE(cash_difference, 0) as cash_difference,
                COALESCE(closing_balance, 0) as closing_balance,
                notes
            FROM close_register_logs
            WHERE user_id = %s 
              AND status = 'open'
            ORDER BY id DESC 
            LIMIT 1
        """, (user_id,))
        
        register = cursor.fetchone()

        if not register:
            # Cursor/connection are released in the finally clause.
            return jsonify({
                'success': False,
                'error': 'No open register found for this user'
            }), 404

        register_id = register['id']
        store_id = register['store_id']
        login_time = register['login_time']
        logout_time = register['logout_time']
        status = register['status']
        
        end_time = logout_time if logout_time else datetime.now()
        opening_cash = float(register['opening_cash'] or 0)
        counted_cash = float(register['counted_cash'] or 0)

        # ============================================
        # STEP 2: MANUAL CASH TRANSACTIONS
        # ============================================
        cursor.execute("""
            SELECT 
                transaction_type,
                COUNT(*) as count,
                COALESCE(SUM(amount), 0) as total
            FROM cash_transactions
            WHERE register_log_id = %s
              AND transaction_type IN ('add', 'remove')
            GROUP BY transaction_type
        """, (register_id,))
        
        cash_transactions = cursor.fetchall()
        cash_added_manual = 0
        cash_removed_manual = 0
        
        for trans in cash_transactions:
            trans_type = trans['transaction_type']
            amount = float(trans['total'] or 0)
            
            if trans_type == 'add':
                cash_added_manual += amount
            elif trans_type == 'remove':
                cash_removed_manual += amount

        # ============================================
        # STEP 3: SALES BREAKDOWN WITH SUSPENDED DETAILS
        # ============================================
        cursor.execute("""
            SELECT 
                COUNT(*) as total_invoices,
                COALESCE(SUM(invoiceTotal), 0) as total_sale,
                
                -- Received sales (fully paid)
                COALESCE(SUM(CASE WHEN status = 'received' THEN invoiceTotal ELSE 0 END), 0) as received_sale,
                COUNT(CASE WHEN status = 'received' THEN 1 END) as received_count,
                
                -- Total suspended sales (invoice totals)
                COALESCE(SUM(CASE WHEN status = 'suspended' THEN invoiceTotal ELSE 0 END), 0) as suspended_sale,
                COUNT(CASE WHEN status = 'suspended' THEN 1 END) as suspended_count,
                
                -- Suspended - Unpaid (Full invoice total, no payments)
                COALESCE(SUM(CASE 
                    WHEN status = 'suspended' AND payment_status = 'unpaid' 
                    THEN invoiceTotal ELSE 0 
                END), 0) as suspended_unpaid_total,
                COUNT(CASE 
                    WHEN status = 'suspended' AND payment_status = 'unpaid' 
                    THEN 1 
                END) as suspended_unpaid_count,
                
                -- Suspended - Partial (Invoice total)
                COALESCE(SUM(CASE 
                    WHEN status = 'suspended' AND payment_status = 'partial' 
                    THEN invoiceTotal ELSE 0 
                END), 0) as suspended_partial_total,
                COUNT(CASE 
                    WHEN status = 'suspended' AND payment_status = 'partial' 
                    THEN 1 
                END) as suspended_partial_count
                
            FROM invoice_sale
            WHERE cashier_user_id = %s 
            AND sale_date BETWEEN %s AND %s
        """, (user_id, login_time, end_time))
        
        sale_data = cursor.fetchone()
        
        total_sale = float(sale_data['total_sale'])
        received_sale = float(sale_data['received_sale'])
        suspended_sale = float(sale_data['suspended_sale'])
        
        # Unpaid suspended sales - Full invoice total is the due amount
        suspended_unpaid_due = float(sale_data['suspended_unpaid_total'] or 0)
        suspended_unpaid_count = sale_data['suspended_unpaid_count'] or 0

        # ============================================
        # STEP 3B: CALCULATE ACTUAL DUE FOR PARTIAL PAYMENTS
        # ============================================
        cursor.execute("""
            SELECT 
                inv.invoice_id,
                inv.invoiceTotal,
                COALESCE(SUM(sp.amount), 0) as total_paid,
                (inv.invoiceTotal - COALESCE(SUM(sp.amount), 0)) as due_amount
            FROM invoice_sale inv
            LEFT JOIN sale_payments sp ON inv.invoice_id = sp.sale_id
            WHERE inv.cashier_user_id = %s
            AND inv.sale_date BETWEEN %s AND %s
            AND inv.status = 'suspended'
            AND inv.payment_status = 'partial'
            GROUP BY inv.invoice_id, inv.invoiceTotal
        """, (user_id, login_time, end_time))
        
        partial_invoices = cursor.fetchall()
        suspended_partial_due = 0
        suspended_partial_count = len(partial_invoices)
        
        for invoice in partial_invoices:
            due_amount = float(invoice['due_amount'] or 0)
            suspended_partial_due += due_amount
            
            print(f"  📊 Partial Invoice #{invoice['invoice_id']}: "
                  f"Total: Rs. {invoice['invoiceTotal']}, "
                  f"Paid: Rs. {invoice['total_paid']}, "
                  f"Due: Rs. {due_amount}")

        # Calculate total actual due amount
        total_suspended_due = suspended_unpaid_due + suspended_partial_due

        print(f"\n📊 === SUSPENDED SALES DUE AMOUNTS ===")
        print(f"  Unpaid Due: Rs. {suspended_unpaid_due:.2f} ({suspended_unpaid_count} sales)")
        print(f"  Partial Due: Rs. {suspended_partial_due:.2f} ({suspended_partial_count} sales)")
        print(f"  💰 TOTAL DUE: Rs. {total_suspended_due:.2f}")

        # ============================================
        # STEP 4: CHECK PAYMENT_METHOD COLUMN TYPE
        # ============================================
        cursor.execute("""
            SELECT DATA_TYPE 
            FROM INFORMATION_SCHEMA.COLUMNS 
            WHERE TABLE_NAME = 'sale_payments' 
            AND COLUMN_NAME = 'payment_method'
            AND TABLE_SCHEMA = DATABASE()
        """)
        
        column_info = cursor.fetchone()
        is_int_payment_method = column_info and column_info['DATA_TYPE'] in ('int', 'bigint', 'smallint')

        # ============================================
        # STEP 5: PAYMENT BREAKDOWN (ACTUAL CASH RECEIVED)
        # The running total spans every payment of an invoice, so any amount
        # beyond invoiceTotal is capped off the payment that crossed it.
        # ============================================
        if is_int_payment_method:
            cursor.execute("""
                WITH invoice_payments AS (
                    SELECT 
                        sp.sale_id,
                        inv.invoiceTotal,
                        pm.method_name as payment_method,
                        sp.amount,
                        SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total
                    FROM sale_payments sp
                    INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
                    LEFT JOIN payment_methods pm ON sp.payment_method = pm.id
                    WHERE inv.cashier_user_id = %s 
                    AND inv.sale_date BETWEEN %s AND %s
                )
                SELECT 
                    payment_method,
                    COUNT(DISTINCT sale_id) as transaction_count,
                    SUM(
                        CASE 
                            WHEN running_total > invoiceTotal THEN 
                                GREATEST(0, invoiceTotal - (running_total - amount))
                            ELSE amount
                        END
                    ) as total_amount
                FROM invoice_payments
                GROUP BY payment_method
                ORDER BY total_amount DESC
            """, (user_id, login_time, end_time))
        else:
            cursor.execute("""
                WITH invoice_payments AS (
                    SELECT 
                        sp.sale_id,
                        inv.invoiceTotal,
                        sp.payment_method,
                        sp.amount,
                        SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total
                    FROM sale_payments sp
                    INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
                    WHERE inv.cashier_user_id = %s 
                    AND inv.sale_date BETWEEN %s AND %s
                )
                SELECT 
                    payment_method,
                    COUNT(DISTINCT sale_id) as transaction_count,
                    SUM(
                        CASE 
                            WHEN running_total > invoiceTotal THEN 
                                GREATEST(0, invoiceTotal - (running_total - amount))
                            ELSE amount
                        END
                    ) as total_amount
                FROM invoice_payments
                GROUP BY payment_method
                ORDER BY total_amount DESC
            """, (user_id, login_time, end_time))
        
        payment_methods = cursor.fetchall()
        
        cash_payment_total = 0
        card_payment_total = 0
        other_payment_total = 0
        payment_breakdown = {}
        
        for pm in payment_methods:
            method = str(pm['payment_method']) if pm['payment_method'] else 'unknown'
            amount = float(pm['total_amount'] or 0)
            payment_breakdown[method] = amount
            
            if method.lower() == 'cash':
                cash_payment_total += amount
            elif method.lower() in ('card', 'credit card', 'debit card', 'credit/debit card'):
                card_payment_total += amount
            else:
                other_payment_total += amount

        # ============================================
        # STEP 6: SALE RETURNS (store-wide since login)
        # ============================================
        cursor.execute("""
            SELECT 
                COUNT(*) as return_count,
                COALESCE(SUM(total_return_amount), 0) as total_return
            FROM sale_return
            WHERE return_date BETWEEN %s AND %s
            AND store_id = %s
        """, (login_time, end_time, store_id))
        
        return_data = cursor.fetchone()
        total_return_amount = float(return_data['total_return'] or 0)
        return_count = return_data['return_count'] or 0

        # ============================================
        # STEP 7: EXPENSES
        # Same filter as get_register_status so both endpoints agree:
        # - created_at (datetime) rather than `date`: if `date` is a DATE
        #   column, a same-day value compares as midnight and falls before
        #   login_time, silently dropping today's expenses.
        # - created_by AND store_id: only this cashier's expenses in this
        #   store come out of this register's drawer (the previous OR also
        #   pulled in other cashiers' store expenses).
        # ============================================
        cursor.execute("""
            SELECT 
                COUNT(*) as expense_count,
                COALESCE(SUM(amount), 0) as total_expense
            FROM expenses
            WHERE created_by = %s
            AND created_at BETWEEN %s AND %s
            AND store_id = %s
        """, (user_id, login_time, end_time, store_id))
        
        expense_data = cursor.fetchone()
        total_expense = float(expense_data['total_expense'] or 0)
        expense_count = expense_data['expense_count'] or 0

        # ============================================
        # STEP 8: PAYMENT STATUS BREAKDOWN
        # ============================================
        cursor.execute("""
            SELECT 
                payment_status,
                COUNT(*) as count,
                COALESCE(SUM(invoiceTotal), 0) as total_amount
            FROM invoice_sale
            WHERE cashier_user_id = %s
            AND sale_date BETWEEN %s AND %s
            GROUP BY payment_status
        """, (user_id, login_time, end_time))
        
        payment_status_breakdown = cursor.fetchall()

        # ============================================
        # STEP 9: EXPECTED CASH CALCULATION
        # ============================================
        # Expected Cash = Opening Cash + Cash Payments RECEIVED + Manual Adds - Manual Removes - Expenses - Returns
        # Suspended/unpaid balances are excluded: that money is still owed.
        
        total_payment = cash_payment_total + card_payment_total + other_payment_total
        
        expected_cash = (
            opening_cash +                    # Starting cash
            cash_payment_total +              # Cash payments actually RECEIVED
            cash_added_manual -               # Manual cash added
            cash_removed_manual -             # Manual cash removed
            total_expense -                   # Expenses paid out
            total_return_amount               # Returns refunded
        )
        
        current_cash_in_hand = expected_cash  # Same as expected cash
        # A zero counted_cash is treated as "not counted yet" -> no difference.
        cash_difference = counted_cash - expected_cash if counted_cash > 0 else 0

        print(f"\n💰 === EXPECTED CASH CALCULATION ===")
        print(f"  Opening Cash:        + Rs. {opening_cash:.2f}")
        print(f"  Cash Received:       + Rs. {cash_payment_total:.2f}")
        print(f"  Manual Cash Added:   + Rs. {cash_added_manual:.2f}")
        print(f"  Manual Cash Removed: - Rs. {cash_removed_manual:.2f}")
        print(f"  Expenses:            - Rs. {total_expense:.2f}")
        print(f"  Returns:             - Rs. {total_return_amount:.2f}")
        print(f"  ══════════════════════════════════")
        print(f"  EXPECTED CASH:       = Rs. {expected_cash:.2f}")
        print(f"\n  💡 NOTE: Suspended sales (Rs. {total_suspended_due:.2f}) are NOT included")
        print(f"           because that money hasn't been received yet")

        # ============================================
        # STEP 10: RETURN RESPONSE
        # ============================================
        return jsonify({
            "success": True,
            
            # Register info
            "register_id": register_id,
            "status": status,
            "login_time": login_time.strftime('%Y-%m-%d %H:%M:%S') if login_time else None,
            "logout_time": logout_time.strftime('%Y-%m-%d %H:%M:%S') if logout_time else None,
            "store_id": store_id,
            
            # Opening balance & manual transactions
            "opening_cash": opening_cash,
            "cash_added_manual": cash_added_manual,
            "cash_removed_manual": cash_removed_manual,
            
            # Sales breakdown
            "total_invoices": sale_data['total_invoices'],
            "total_sale": total_sale,
            "received_sale": received_sale,
            "suspended_sale": suspended_sale,
            "received_count": sale_data['received_count'],
            "suspended_count": sale_data['suspended_count'],
            
            # Suspended breakdown with ACTUAL DUE AMOUNTS
            "suspended_unpaid_due": suspended_unpaid_due,
            "suspended_unpaid_count": suspended_unpaid_count,
            "suspended_partial_due": suspended_partial_due,
            "suspended_partial_count": suspended_partial_count,
            "total_suspended_due": total_suspended_due,
            
            # Payment breakdown (ACTUAL RECEIVED)
            "cash_payment": cash_payment_total,
            "card_payment": card_payment_total,
            "other_payment": other_payment_total,
            "total_payment": total_payment,
            "payment_breakdown": payment_breakdown,
            "payment_methods": [
                {
                    "method": str(pm['payment_method']) if pm['payment_method'] else 'unknown',
                    "count": pm['transaction_count'],
                    "amount": float(pm['total_amount'] or 0)
                }
                for pm in payment_methods
            ],
            
            # Payment status
            "payment_status": [
                {
                    "status": ps['payment_status'],
                    "count": ps['count'],
                    "amount": float(ps['total_amount'] or 0)
                }
                for ps in payment_status_breakdown
            ],
            
            # Deductions
            "total_return_amount": total_return_amount,
            "return_count": return_count,
            "total_expense": total_expense,
            "expense_count": expense_count,
            
            # Cash calculations
            "cash_in_hand": current_cash_in_hand,      # Current cash that should be in register
            "expected_cash": expected_cash,             # Expected cash (same as above)
            "counted_cash": counted_cash,               # Actual counted cash (if closing)
            "cash_difference": cash_difference,         # Difference (over/short)
            
            # Notes
            "notes": register.get('notes')
            
        }), 200

    except Exception as e:
        print(f"❌ Error in get_cash_summary: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
    finally:
        # Release DB resources on every path, including failed queries.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
        
@sale_bp.route('/close_register', methods=['POST'])
@jwt_required()
@role_required('admin', 'cashier', 'manager')
def close_register():
    """
    Close the open cash-register session of ``user_id`` and persist its totals.

    Looks up the newest ``close_register_logs`` row for ``user_id`` with
    ``status = 'open'`` and ``logout_time IS NULL``, recomputes manual cash
    movements (``cash_transactions``) and card sales for the session window
    (``login_time`` .. now), then updates that row and sets ``status = 'close'``.

    Request JSON (amounts default to 0 when absent, null or empty):
        user_id (required)          -- owner of the register to close
        cash_in_hand                -- expected cash; stored as expected_cash
                                       and total_cash
        cash_in_hand_while_closing  -- counted cash; stored as counted_cash,
                                       actual_cash and closing_balance
        total_sale_amount, sale_return_amount, total_expense -- stored as-is
        total_payment               -- logged only, not stored
        notes                       -- free text

    close_register_logs columns written (actual schema names):
        logout_time, opening_cash, counted_cash, total_cash_added,
        total_cash_removed, opening_balance, closing_balance, expected_cash,
        actual_cash, cash_difference, total_cash, total_sale_amount,
        total_card_sales, sale_return_amount, total_expense, notes, status

    Returns:
        200 with a closing summary; 400 when the body or user_id is missing,
        an amount is not numeric, or the UPDATE matched no open row; 404 when
        no open register exists; 500 on unexpected errors (rolled back).
    """
    # NOTE(review): user_id is taken from the request body, not from
    # get_jwt_identity(), so any caller with an allowed role can close another
    # user's register. Confirm this is intended (e.g. managers closing for cashiers).
    conn = None
    cursor = None
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        print("\n" + "="*50)
        print("📥 CLOSE REGISTER REQUEST RECEIVED")
        print("="*50)
        print("📦 Full Payload:")
        print(json.dumps(data, indent=2, default=str))

        # ============================================
        # VALIDATE REQUIRED FIELDS
        # ============================================
        user_id = data.get('user_id')

        if not user_id:
            return jsonify({
                'success': False,
                'error': 'Missing user_id'
            }), 400

        # Parse frontend amounts before touching the DB so bad input is a 400.
        amounts = {}
        for key in ('cash_in_hand', 'cash_in_hand_while_closing',
                    'total_sale_amount', 'total_payment',
                    'sale_return_amount', 'total_expense'):
            try:
                amounts[key] = float(data.get(key) or 0)
            except (TypeError, ValueError):
                return jsonify({
                    'success': False,
                    'error': f'Invalid numeric value for {key}'
                }), 400
        notes = data.get('notes', '')

        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        # ============================================
        # GET OPEN REGISTER
        # ============================================
        cursor.execute("""
            SELECT 
                id, 
                store_id, 
                opening_balance,
                opening_cash,
                login_time,
                status
            FROM close_register_logs
            WHERE user_id = %s 
              AND status = 'open'
              AND logout_time IS NULL
            ORDER BY id DESC
            LIMIT 1
        """, (user_id,))

        register = cursor.fetchone()

        if not register:
            return jsonify({
                'success': False,
                'error': 'No open register found for this user'
            }), 404

        register_id = register['id']
        store_id = register['store_id']
        opening_balance = float(register['opening_balance'] or 0)
        opening_cash = float(register['opening_cash'] or 0)
        login_time = register['login_time']
        end_time = datetime.now()

        print(f"\n✅ Found open register:")
        print(f"   Register ID: {register_id}")
        print(f"   Store ID: {store_id}")
        print(f"   Opening Balance: Rs. {opening_balance}")
        print(f"   Opening Cash: Rs. {opening_cash}")
        print(f"   Login Time: {login_time}")

        # ============================================
        # MAP FRONTEND DATA TO DATABASE COLUMN NAMES
        # ============================================
        expected_cash = amounts['cash_in_hand']                  # What should be in register
        counted_cash = amounts['cash_in_hand_while_closing']     # What was actually counted
        total_sale_amount = amounts['total_sale_amount']
        total_payment = amounts['total_payment']
        sale_return_amount = amounts['sale_return_amount']
        total_expense = amounts['total_expense']

        print(f"\n💰 Cash Amounts:")
        print(f"   Expected Cash: Rs. {expected_cash}")
        print(f"   Counted Cash: Rs. {counted_cash}")
        print(f"   Total Sale Amount: Rs. {total_sale_amount}")
        print(f"   Total Payment: Rs. {total_payment}")
        print(f"   Sale Returns: Rs. {sale_return_amount}")
        print(f"   Expenses: Rs. {total_expense}")

        # ============================================
        # CASH TRANSACTIONS (added/removed) AND CARD SALES
        # ============================================
        total_cash_added, total_cash_removed = _close_register_cash_movements(cursor, register_id)

        print(f"   Cash Added (Manual): Rs. {total_cash_added}")
        print(f"   Cash Removed: Rs. {total_cash_removed}")

        total_card_sales = _close_register_card_sales(cursor, user_id, login_time, end_time)

        print(f"   Card Sales: Rs. {total_card_sales}")

        # ============================================
        # CALCULATE VARIANCE & CLOSING BALANCE
        # ============================================
        variance = counted_cash - expected_cash
        closing_balance = counted_cash  # What was counted
        actual_cash = counted_cash      # Same as counted

        print(f"\n🧮 Calculations:")
        print(f"   Variance: Rs. {variance} ({('OVER' if variance > 0 else 'SHORT') if variance != 0 else 'EXACT'})")
        print(f"   Closing Balance: Rs. {closing_balance}")

        # ============================================
        # UPDATE REGISTER LOG - USING ACTUAL SCHEMA COLUMNS
        # ============================================
        update_query = """
            UPDATE close_register_logs
            SET 
                -- Timing
                logout_time = %s,
                
                -- Cash amounts (ACTUAL DATABASE COLUMN NAMES)
                opening_cash = %s,
                counted_cash = %s,
                total_cash_added = %s,
                total_cash_removed = %s,
                opening_balance = %s,
                closing_balance = %s,
                expected_cash = %s,
                actual_cash = %s,
                cash_difference = %s,
                total_cash = %s,
                
                -- Sales tracking
                total_sale_amount = %s,
                total_card_sales = %s,
                
                -- Deductions
                sale_return_amount = %s,
                total_expense = %s,
                
                -- Metadata
                notes = %s,
                status = 'close'
                
            WHERE id = %s
              AND user_id = %s
              AND status = 'open'
        """

        update_params = (
            end_time,               # logout_time
            opening_cash,           # opening_cash
            counted_cash,           # counted_cash (actual counted)
            total_cash_added,       # total_cash_added (manual additions)
            total_cash_removed,     # total_cash_removed (manual removals)
            opening_balance,        # opening_balance
            closing_balance,        # closing_balance
            expected_cash,          # expected_cash
            actual_cash,            # actual_cash
            variance,               # cash_difference
            expected_cash,          # total_cash
            total_sale_amount,      # total_sale_amount
            total_card_sales,       # total_card_sales
            sale_return_amount,     # sale_return_amount
            total_expense,          # total_expense
            notes,                  # notes
            register_id,            # WHERE id
            user_id,                # WHERE user_id
        )

        print(f"\n📝 Executing UPDATE query...")

        cursor.execute(update_query, update_params)
        rows_affected = cursor.rowcount

        print(f"   ✅ Rows affected: {rows_affected}")

        if rows_affected == 0:
            # Nothing changed; the finally block releases the connection.
            return jsonify({
                'success': False,
                'error': 'Failed to close register. It may already be closed or does not exist.'
            }), 400

        conn.commit()

        # ============================================
        # VERIFY DATA WAS SAVED (diagnostic only; the commit has happened)
        # ============================================
        cursor.execute("""
            SELECT 
                opening_cash,
                counted_cash,
                total_cash_added,
                total_cash_removed,
                opening_balance,
                closing_balance,
                expected_cash,
                actual_cash,
                cash_difference,
                total_cash,
                total_sale_amount,
                total_card_sales,
                sale_return_amount,
                total_expense,
                status,
                logout_time
            FROM close_register_logs
            WHERE id = %s
        """, (register_id,))

        saved_data = cursor.fetchone()

        if saved_data:
            print(f"\n✅ DATA VERIFICATION:")
            print(f"   Status: {saved_data['status']}")
            print(f"   Opening Cash: Rs. {saved_data['opening_cash']}")
            print(f"   Counted Cash: Rs. {saved_data['counted_cash']}")
            print(f"   Expected Cash: Rs. {saved_data['expected_cash']}")
            print(f"   Variance: Rs. {saved_data['cash_difference']}")
            print(f"   Total Sales: Rs. {saved_data['total_sale_amount']}")
            print(f"   Card Sales: Rs. {saved_data['total_card_sales']}")
            print(f"   Returns: Rs. {saved_data['sale_return_amount']}")
            print(f"   Expenses: Rs. {saved_data['total_expense']}")
            print(f"   Logout Time: {saved_data['logout_time']}")
        else:
            # Do not fail the request: the register is already closed and committed.
            print(f"\n⚠️ Could not re-read register {register_id} after commit")

        # ============================================
        # SUCCESS LOGGING
        # ============================================
        print(f"\n" + "="*50)
        print("✅ REGISTER CLOSED SUCCESSFULLY")
        print("="*50)
        print(f"   User ID: {user_id}")
        print(f"   Register ID: {register_id}")
        print(f"   Store ID: {store_id}")
        print(f"   Opening Cash: Rs. {opening_cash}")
        print(f"   Expected Cash: Rs. {expected_cash}")
        print(f"   Counted Cash: Rs. {counted_cash}")
        print(f"   Variance: Rs. {variance}")
        print(f"   Total Sales: Rs. {total_sale_amount}")
        print(f"   Card Sales: Rs. {total_card_sales}")
        print(f"   Returns: Rs. {sale_return_amount}")
        print(f"   Expenses: Rs. {total_expense}")
        print(f"   Closing Balance: Rs. {closing_balance}")
        print("="*50 + "\n")

        return jsonify({
            'success': True,
            'message': 'Register closed successfully',
            'data': {
                'register_id': register_id,
                'variance': float(variance),
                'expected_cash': float(expected_cash),
                'actual_cash': float(counted_cash),
                'closing_balance': float(closing_balance),
                'total_card_sales': float(total_card_sales),
                'total_sale_amount': float(total_sale_amount),
                'sale_return_amount': float(sale_return_amount),
                'total_expense': float(total_expense),
                'status': 'close'
            }
        }), 200

    except Exception as e:
        print(f"\n" + "="*50)
        print("❌ ERROR CLOSING REGISTER")
        print("="*50)
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        print("="*50 + "\n")

        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # The connection may already be broken; report the original error.
                print(f"⚠️ Rollback failed: {rollback_error}")

        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

    finally:
        # Single place where DB resources are released, on every exit path.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


def _close_register_cash_movements(cursor, register_id):
    """Return ``(total_added, total_removed)`` from cash_transactions for one register log."""
    cursor.execute("""
        SELECT 
            COALESCE(SUM(CASE WHEN transaction_type = 'add' THEN amount ELSE 0 END), 0) as total_added,
            COALESCE(SUM(CASE WHEN transaction_type = 'remove' THEN amount ELSE 0 END), 0) as total_removed
        FROM cash_transactions
        WHERE register_log_id = %s
    """, (register_id,))
    row = cursor.fetchone() or {}
    return float(row.get('total_added') or 0), float(row.get('total_removed') or 0)


def _close_register_card_sales(cursor, user_id, start_time, end_time):
    """
    Return the card-payment total for invoices of ``user_id`` in a time window.

    Only invoices whose ``sale_date`` is BETWEEN ``start_time`` and
    ``end_time`` are counted. ``sale_payments.payment_method`` is either an
    integer FK into ``payment_methods`` or a text/enum column. INFORMATION_SCHEMA
    is consulted to pick the matching join and filter column.

    Each payment is capped so that a sale's running total never exceeds its
    ``invoiceTotal``.

    NOTE(review): window functions are evaluated after WHERE, so
    ``running_total`` covers only the card rows of each sale, not earlier
    non-card payments. Confirm this is the intended allocation.
    """
    cursor.execute("""
        SELECT DATA_TYPE 
        FROM INFORMATION_SCHEMA.COLUMNS 
        WHERE TABLE_NAME = 'sale_payments' 
        AND COLUMN_NAME = 'payment_method'
        AND TABLE_SCHEMA = DATABASE()
    """)
    column_info = cursor.fetchone()
    is_int_payment_method = bool(column_info) and column_info['DATA_TYPE'] in ('int', 'bigint', 'smallint')

    if is_int_payment_method:
        # INT version - FK to payment_methods
        method_join = "INNER JOIN payment_methods pm ON sp.payment_method = pm.id"
        method_column = "pm.method_name"
    else:
        # VARCHAR/ENUM version
        method_join = ""
        method_column = "sp.payment_method"

    # Only the fixed, code-controlled fragments above are interpolated;
    # all caller-supplied values remain bound parameters.
    cursor.execute(f"""
        SELECT COALESCE(SUM(
            CASE 
                WHEN running_total > invoiceTotal THEN 
                    GREATEST(0, invoiceTotal - (running_total - amount))
                ELSE amount
            END
        ), 0) as total_card
        FROM (
            SELECT 
                sp.sale_id,
                inv.invoiceTotal,
                sp.amount,
                SUM(sp.amount) OVER (PARTITION BY sp.sale_id ORDER BY sp.id) as running_total
            FROM sale_payments sp
            INNER JOIN invoice_sale inv ON sp.sale_id = inv.invoice_id
            {method_join}
            WHERE inv.cashier_user_id = %s
              AND inv.sale_date BETWEEN %s AND %s
              AND LOWER({method_column}) IN ('card', 'credit card', 'debit card', 'credit/debit card')
        ) sub
    """, (user_id, start_time, end_time))

    card_result = cursor.fetchone() or {}
    return float(card_result.get('total_card') or 0)


    
# ============================================
# ✅ COMPLETE: GET INVOICE (Alternative)
# ============================================
@sale_bp.route('/get_invoice/<int:invoice_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_invoice(invoice_id):
    """
    Get invoice details (alternative endpoint).

    Returns the full ``invoice_sale`` row plus its ``sale_products`` lines,
    each joined with ``products.product_name``. Decimal columns are converted
    to float so they serialize as JSON numbers.

    Args:
        invoice_id: ``invoice_sale.invoice_id`` to fetch.

    Returns:
        200 ``{"invoice": {...}, "products": [...]}``; 404 if the invoice does
        not exist; 500 on unexpected errors.
    """
    def _plain(row):
        # Decimal -> float; every other value is passed through untouched.
        return {k: float(v) if isinstance(v, Decimal) else v for k, v in row.items()}

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("SELECT * FROM invoice_sale WHERE invoice_id = %s", (invoice_id,))
        invoice = cursor.fetchone()

        if not invoice:
            return jsonify({"error": "Invoice not found"}), 404

        cursor.execute("""
            SELECT sp.*, p.product_name 
            FROM sale_products sp 
            JOIN products p ON sp.product_id = p.id 
            WHERE sp.invoice_id = %s
        """, (invoice_id,))
        products = cursor.fetchall()

        return jsonify({
            "invoice": _plain(invoice),
            "products": [_plain(product) for product in products]
        }), 200

    except Exception as e:
        print("❌ Error in /get_invoice:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Release DB resources on every path, including the 404 and error paths.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


# ============================================
# HOLD BILL - Save to Database
# ✅ our_price column removed — single sp.price used
# ============================================
@sale_bp.route('/hold_bill', methods=['POST'])
@jwt_required()
@role_required('admin', 'cashier')
def hold_bill():
    """
    Save a bill as an ``invoice_sale`` row with status ``'hold'``.

    Stock is NOT deducted. Only ``invoice_sale`` and ``sale_products`` rows are
    written (no ``sale_product_items`` entries). Each line stores a single
    selling price; the dropped ``our_price`` column is not used.

    Request JSON:
        warehouse_id, store_id (required); products (required, non-empty list);
        customer_id, customer_name, discountValue, grandTotal (optional).
        Each product needs ``product_id`` and ``quantity``. Optional keys are
        variation_id, warehouse_id, product_price/price, total, discountType,
        discount, discountValue, tax_type, product_tax, tax and sales_unit.

    Returns:
        200 with ``invoice_id`` / ``invoice_code``; 400 on missing or invalid
        input; 500 on unexpected errors (transaction rolled back).
    """
    conn = None
    cursor = None
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        print("📥 Hold Bill Data:")
        print(json.dumps(data, indent=4, default=str))

        customer_id    = data.get('customer_id')
        customer_name  = data.get('customer_name', 'Walk-in Customer')
        warehouse_id   = data.get('warehouse_id')
        store_id       = data.get('store_id')
        products       = data.get('products', [])
        discount_value = float(data.get('discountValue') or 0)
        grand_total    = float(data.get('grandTotal') or 0)

        if not warehouse_id:
            return jsonify({"error": "Warehouse ID is required"}), 400
        if not store_id:
            return jsonify({"error": "Store ID is required"}), 400
        if not isinstance(products, list) or not products:
            return jsonify({"error": "No products to hold"}), 400

        current_user = get_jwt_identity()

        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("START TRANSACTION")

        # Generate hold invoice code.
        # NOTE(review): the sequence is COUNT(*)+1 over all invoice_sale rows
        # created today, so concurrent requests can produce the same code.
        # Consider a unique constraint or a dedicated sequence.
        today_str = datetime.now().strftime("%Y%m%d")
        cursor.execute("SELECT COUNT(*) as count FROM invoice_sale WHERE DATE(created_at) = CURDATE()")
        count        = cursor.fetchone()["count"] + 1
        invoice_code = f"HOLD-{today_str}-{count:04d}"

        cursor.execute("""
            INSERT INTO invoice_sale (
                invoice_code, tenderAmount, remainingAmount, invoiceTotal,
                payment_method, payment_status, customer_id, discount, status,
                warehouse_id, store_id, cashier_user_id, tax
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            invoice_code, 0, grand_total, grand_total,
            'unpaid', 'unpaid', customer_id, discount_value, 'hold',
            warehouse_id, store_id, current_user, 0
        ))

        invoice_id = cursor.lastrowid

        # Insert line items WITHOUT deducting stock; a single selling price is stored.
        for idx, product in enumerate(products, 1):
            product_id   = int(product['product_id'])
            variation_id = product.get('variation_id')
            if variation_id:
                variation_id = int(variation_id)

            quantity             = float(product['quantity'])
            discount_value_prod  = float(product.get('discountValue', 0))
            total                = float(product.get('total', 0))
            product_warehouse_id = product.get('warehouse_id') or warehouse_id

            # Single selling price (our_price / market_price removed).
            selling_price = float(product.get('product_price') or product.get('price', 0))

            sale_unit_identifier = get_unit_identifier(cursor, product.get('sales_unit'))

            print(f"   [{idx}] Product {product_id}: Qty {quantity}, Price: Rs.{selling_price}")

            cursor.execute("""
                INSERT INTO sale_products (
                    invoice_id, product_id, variation_id, warehouse_id,
                    price, quantity, total,
                    discount_type, product_discount,
                    tax_type, product_tax,
                    discount, tax, unit_price, sale_unit
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """, (
                invoice_id, product_id, variation_id, product_warehouse_id,
                selling_price, quantity, total,
                product.get('discountType'), float(product.get('discount', 0)),
                product.get('tax_type'),     float(product.get('product_tax', 0)),
                discount_value_prod,         float(product.get('tax', 0)),
                selling_price, sale_unit_identifier
            ))

        conn.commit()

        print(f"✅ Bill {invoice_code} saved as HOLD (stock not deducted)")
        print(f"   Customer: {customer_name}")
        print(f"   Total Products: {len(products)}")

        return jsonify({
            "message":      "Bill held successfully",
            "invoice_id":   invoice_id,
            "invoice_code": invoice_code
        }), 200

    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # Connection may be broken; still report the original error as JSON.
                print("⚠️ Rollback failed:", rollback_error)
        print("❌ Error holding bill:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


# ============================================
# GET HOLD BILLS LIST
# ============================================
@sale_bp.route('/get_hold_bills', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_hold_bills():
    """
    List every ``invoice_sale`` row with status ``'hold'``, newest first.

    Each entry carries the customer and warehouse names (with fallbacks when
    the joined row is missing) and the number of ``sale_products`` lines.

    Returns:
        200 ``{"hold_bills": [...]}``; 500 on unexpected errors.
    """
    conn = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT
                i.invoice_id, i.invoice_code, i.invoiceTotal,
                i.discount, i.sale_date, i.customer_id,
                i.warehouse_id, i.store_id,
                c.name            AS customer_name,
                w.warehouse_name,
                COUNT(sp.sale_product_id) AS item_count
            FROM invoice_sale i
            LEFT JOIN customers  c  ON i.customer_id  = c.id
            LEFT JOIN warehouses w  ON i.warehouse_id = w.id
            LEFT JOIN sale_products sp ON i.invoice_id = sp.invoice_id
            WHERE i.status = 'hold'
            GROUP BY i.invoice_id
            ORDER BY i.sale_date DESC
        """)

        result = [
            {
                "invoice_id":    bill['invoice_id'],
                "reference":     bill['invoice_code'],
                "timestamp":     bill['sale_date'].isoformat() if bill['sale_date'] else '',
                "customerId":    bill['customer_id'],
                "customerName":  bill['customer_name'] or 'Walk-in Customer',
                "warehouseId":   bill['warehouse_id'],
                "warehouseName": bill['warehouse_name'] or 'Unknown',
                "storeId":       bill['store_id'],
                "discountValue": float(bill['discount'])     if bill['discount']     else 0.0,
                "grandTotal":    float(bill['invoiceTotal']) if bill['invoiceTotal'] else 0.0,
                "itemCount":     bill['item_count'],
            }
            for bill in cursor.fetchall()
        ]

        return jsonify({"hold_bills": result}), 200

    except Exception as e:
        print("❌ Error getting hold bills:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Close DB resources even when the query or row conversion raises.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


# ============================================
# GET SINGLE HOLD BILL DETAILS
# ✅ our_price removed — single sp.price used
# ============================================
@sale_bp.route('/get_hold_bill/<int:invoice_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_hold_bill(invoice_id):
    """
    Get detailed information for a single hold bill.

    Loads the ``invoice_sale`` row (which must have status ``'hold'``) and its
    ``sale_products`` lines. For each line it reports the current available
    stock in the line's warehouse at the stored selling price, via
    ``get_available_stock_by_warehouse``. A single ``sp.price`` is used; the
    dropped ``our_price`` column and the non-existent ``sales_unit`` table
    are not referenced.

    Args:
        invoice_id: ``invoice_sale.invoice_id`` of the hold bill.

    Returns:
        200 with invoice header and ``products`` list; 404 if no hold bill
        has that id; 500 on unexpected errors.
    """
    conn = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT i.*, c.name AS customer_name, w.warehouse_name
            FROM invoice_sale i
            LEFT JOIN customers  c ON i.customer_id  = c.id
            LEFT JOIN warehouses w ON i.warehouse_id = w.id
            WHERE i.invoice_id = %s AND i.status = 'hold'
        """, (invoice_id,))

        invoice = cursor.fetchone()
        if not invoice:
            return jsonify({"error": "Hold bill not found"}), 404

        cursor.execute("""
            SELECT
                sp.sale_product_id,
                sp.product_id, sp.variation_id, sp.quantity, sp.warehouse_id,
                sp.price        AS selling_price,
                sp.total, sp.discount,
                sp.discount_type, sp.product_discount,
                sp.tax_type, sp.product_tax,
                sp.unit_price, sp.sale_unit,
                p.product_name, p.sku,
                pv.variation_name, pv.variation_type, pv.variation_sku
            FROM sale_products sp
            JOIN products p ON sp.product_id = p.id
            LEFT JOIN product_variations pv ON sp.variation_id = pv.id
            WHERE sp.invoice_id = %s
        """, (invoice_id,))

        products = cursor.fetchall()

        products_list = []
        for idx, prod in enumerate(products, 1):
            product_warehouse = prod.get('warehouse_id') or invoice['warehouse_id']

            # Single selling price; fall back to unit_price, then 0 if both are NULL.
            selling_price = float(prod.get('selling_price') or prod.get('unit_price') or 0)

            # Current stock at this exact price point in the line's warehouse.
            available = get_available_stock_by_warehouse(
                cursor, prod['product_id'], prod['variation_id'],
                product_warehouse, price=selling_price
            )

            print(f"   [{idx}] Product {prod['product_id']}: Price Rs.{selling_price}, Stock: {available}")

            sale_unit_value = prod.get('sale_unit')
            sale_unit_name  = str(sale_unit_value) if sale_unit_value else 'Unit'

            products_list.append({
                "product_id":       prod['product_id'],
                "variation_id":     prod['variation_id'],
                "product_name":     prod['product_name'],
                "variation_name":   prod.get('variation_name'),
                "variation_type":   prod.get('variation_type'),
                "sku":              prod.get('variation_sku') or prod.get('sku'),
                "quantity":         float(prod['quantity'] or 0),
                "displayPrice":     selling_price,
                "discountValue":    float(prod['discount'])          if prod['discount']          else 0.0,
                "total":            float(prod['total'] or 0),
                "selling_price":    selling_price,
                "product_price":    selling_price,    # for compatibility
                "price":            selling_price,    # for compatibility
                "product_quantity": available,
                "warehouse_id":     product_warehouse,
                "warehouse_name":   invoice['warehouse_name'],
                "discountType":     prod['discount_type'],
                "discount":         float(prod['product_discount']) if prod['product_discount'] else 0.0,
                "tax_type":         prod['tax_type'],
                "product_tax":      float(prod['product_tax'])      if prod['product_tax']      else 0.0,
                "sales_unit":       prod['sale_unit'],
                "sale_unit_name":   sale_unit_name
            })

        print(f"✅ Hold bill {invoice['invoice_code']} retrieved with {len(products_list)} products")

        invoice_total = float(invoice['invoiceTotal'] or 0)
        return jsonify({
            "invoice_id":         invoice['invoice_id'],
            "invoice_code":       invoice['invoice_code'],
            "customer_id":        invoice['customer_id'],
            "customer_name":      invoice['customer_name'] or 'Walk-in Customer',
            "warehouse_id":       invoice['warehouse_id'],
            "warehouse_name":     invoice['warehouse_name'] or 'Unknown',
            "store_id":           invoice['store_id'],
            "discountValue":      float(invoice['discount'])     if invoice['discount']     else 0.0,
            "medicineGrandTotal": invoice_total + float(invoice['discount'] or 0),
            "grandTotal":         invoice_total,
            "products":           products_list
        }), 200

    except Exception as e:
        print("❌ Error getting hold bill:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Release DB resources on the 404 and error paths too.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


# ============================================
# DELETE HOLD BILL
# ============================================
@sale_bp.route('/delete_hold_bill/<int:invoice_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin', 'cashier')
def delete_hold_bill(invoice_id):
    """
    Delete a hold bill: its ``sale_products`` lines and its ``invoice_sale`` row.

    No stock changes are needed, because hold bills never deducted stock.

    Args:
        invoice_id: ``invoice_sale.invoice_id`` of the bill to delete.

    Returns:
        200 on success; 404 if the invoice does not exist; 400 if it is not in
        ``'hold'`` status; 500 on unexpected errors (transaction rolled back).
    """
    conn = None
    cursor = None
    try:
        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("START TRANSACTION")

        # FOR UPDATE (InnoDB) locks the row until commit/rollback, so a
        # concurrent writer cannot change its status between the check and
        # the DELETEs below.
        cursor.execute("""
            SELECT invoice_code, status
            FROM invoice_sale
            WHERE invoice_id = %s
            FOR UPDATE
        """, (invoice_id,))

        invoice = cursor.fetchone()
        if not invoice:
            conn.rollback()  # end the open transaction before returning
            return jsonify({"error": "Hold bill not found"}), 404
        if invoice['status'] != 'hold':
            conn.rollback()
            return jsonify({"error": "Invoice is not a hold bill"}), 400

        cursor.execute("DELETE FROM sale_products WHERE invoice_id = %s", (invoice_id,))
        # Re-assert the status in the DELETE itself as a second guard.
        cursor.execute(
            "DELETE FROM invoice_sale WHERE invoice_id = %s AND status = 'hold'",
            (invoice_id,),
        )

        conn.commit()

        print(f"✅ Hold bill {invoice['invoice_code']} deleted successfully")

        return jsonify({
            "message": f"Hold bill {invoice['invoice_code']} has been deleted successfully"
        }), 200

    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print("⚠️ Rollback failed:", rollback_error)
        print("❌ Error deleting hold bill:", e)
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()


# ============================================
# RESUME HOLD BILL
# ✅ our_price column removed — single sp.price used
# ✅ Does not reference non-existent sales_unit table
# ============================================
@sale_bp.route('/resume_hold_bill/<int:invoice_id>', methods=['POST'])
@jwt_required()
@role_required('admin', 'cashier')
def resume_hold_bill(invoice_id):
    """
    Resume a hold bill and convert it to a sale.

    ✅ Uses sp.price (single selling price) for stock deduction
    ✅ Uses INTEGER IDs for payment_method in sale_payments table
    ✅ our_price column removed

    Request JSON object (all keys optional):
        tenderAmount:    amount tendered (default 0).
        status:          new invoice status (default 'received'); must not be 'hold'.
        payment_methods: list of {"type": ..., "amount": ...} entries (default []).
        payment_status:  default 'paid'; forced to 'unpaid' when no payment
                         methods are given.

    Stock deduction, payment rows and the invoice update run in a single
    transaction that is rolled back if any step fails.

    Returns:
        200 with invoice_id / invoice_code / status / payment_status on success,
        400 if the body is not a JSON object or the requested status is 'hold',
        404 if no invoice with this ID is currently in 'hold' status,
        500 with the error text otherwise (e.g. insufficient stock).
    """
    conn = None
    cursor = None
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        # `or` fallbacks also cover keys that are present but explicitly null.
        tender_amount    = float(data.get('tenderAmount') or 0)
        status           = data.get('status', 'received')
        payment_methods  = data.get('payment_methods') or []
        payment_status   = data.get('payment_status', 'paid')

        # Resuming into 'hold' would deduct stock yet leave the bill resumable,
        # so a later resume would deduct the same stock a second time.
        if status == 'hold':
            return jsonify({"error": "Cannot resume a hold bill into 'hold' status"}), 400

        conn   = get_db_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")
        cursor.execute("START TRANSACTION")

        # Get hold bill. FOR UPDATE locks the row until commit/rollback, so two
        # concurrent resume requests cannot both see status='hold' and both
        # deduct stock; the second one waits, then no longer matches 'hold'.
        cursor.execute("""
            SELECT * FROM invoice_sale
            WHERE invoice_id = %s AND status = 'hold'
            FOR UPDATE
        """, (invoice_id,))

        invoice = cursor.fetchone()
        if not invoice:
            conn.rollback()  # end the transaction before handing the connection back
            return jsonify({"error": "Hold bill not found"}), 404

        warehouse_id     = invoice['warehouse_id']
        store_id         = invoice['store_id']
        invoice_total    = float(invoice['invoiceTotal'])
        remaining_amount = max(0, invoice_total - tender_amount)

        # Determine the summary payment method stored on invoice_sale.
        if not payment_methods or payment_status == 'unpaid':
            primary_payment_method = "unpaid"
            payment_status         = "unpaid"
        elif len(payment_methods) == 1:
            _, method_name = get_payment_method_id_and_name(cursor, payment_methods[0]['type'])
            primary_payment_method = method_name.lower()
        else:
            primary_payment_method = "multiple"

        # ✅ Get products — single sp.price (our_price removed)
        cursor.execute("""
            SELECT
                sale_product_id, product_id, variation_id,
                quantity, warehouse_id,
                price AS selling_price
            FROM sale_products
            WHERE invoice_id = %s
        """, (invoice_id,))
        products = cursor.fetchall()

        print(f"\n📦 === RESUMING HOLD BILL {invoice['invoice_code']} ===")
        print(f"   Processing {len(products)} products...")

        for idx, product in enumerate(products, 1):
            product_id        = product['product_id']
            variation_id      = product['variation_id']
            quantity          = float(product['quantity'])
            selling_price     = float(product.get('selling_price', 0))
            # A line without its own warehouse falls back to the invoice's warehouse.
            product_warehouse = product.get('warehouse_id') or warehouse_id
            sale_product_id   = product['sale_product_id']

            print(f"\n   [{idx}/{len(products)}] Product {product_id}:")
            print(f"      Selling Price: Rs. {selling_price:.2f}")
            print(f"      Quantity:      {quantity}")
            print(f"      Warehouse:     {product_warehouse}")

            if quantity <= 0:
                raise Exception(f"Invalid quantity {quantity} for product {product_id}")

            # Check stock at selling price
            available = get_available_stock_by_warehouse(
                cursor, product_id, variation_id, product_warehouse, price=selling_price
            )

            print(f"      Available Stock: {available:.2f}")

            if available < quantity:
                raise Exception(
                    f"Insufficient stock at price Rs. {selling_price:.2f} for product {product_id} "
                    f"in warehouse {product_warehouse}. Available: {available:.2f}, Required: {quantity:.2f}"
                )

            # ✅ Deduct stock using single selling price
            batches_used = deduct_stock_from_batches_fifo_warehouse(
                cursor, product_id, variation_id, quantity,
                product_warehouse, store_id, sale_product_id,
                price=selling_price
            )

            print(f"      ✅ Stock deducted: {len(batches_used)} batch(es) used")

        # ✅ Insert payment methods with INTEGER IDs
        if payment_methods and payment_status != 'unpaid':
            print(f"\n💳 Processing {len(payment_methods)} payment method(s)...")
            for idx, payment in enumerate(payment_methods, 1):
                try:
                    payment_type_value = payment.get('type')
                    payment_amount     = float(payment.get('amount', 0))

                    # Zero/negative amounts are skipped rather than recorded.
                    if payment_amount > 0:
                        method_id, method_name = get_payment_method_id_and_name(cursor, payment_type_value)

                        print(f"   [{idx}] {method_name} (ID: {method_id}): Rs. {payment_amount:.2f}")

                        cursor.execute("""
                            INSERT INTO sale_payments (sale_id, payment_method, amount, payment_date)
                            VALUES (%s, %s, %s, NOW())
                        """, (invoice_id, method_id, payment_amount))

                        print(f"      ✅ Payment record created")

                except Exception as pe:
                    # Log which payment failed, then abort the whole transaction.
                    print(f"   ❌ Error processing payment {idx}: {pe}")
                    raise
        else:
            print(f"\n💳 No payments (Status: {payment_status})")

        # Update invoice status
        cursor.execute("""
            UPDATE invoice_sale
            SET status           = %s,
                payment_method   = %s,
                payment_status   = %s,
                tenderAmount     = %s,
                remainingAmount  = %s,
                sale_date        = NOW()
            WHERE invoice_id = %s
        """, (status, primary_payment_method, payment_status, tender_amount, remaining_amount, invoice_id))

        conn.commit()

        print(f"\n✅ === HOLD BILL RESUMED SUCCESSFULLY ===")
        print(f"   Invoice:        {invoice['invoice_code']}")
        print(f"   Products:       {len(products)}")
        print(f"   Status:         {status}")
        print(f"   Payment Status: {payment_status}")
        print(f"   Tender:         Rs. {tender_amount:.2f}")
        print(f"   Remaining:      Rs. {remaining_amount:.2f}")

        return jsonify({
            "message":        "Hold bill resumed successfully",
            "invoice_id":     invoice_id,
            "invoice_code":   invoice['invoice_code'],
            "status":         status,
            "payment_status": payment_status
        }), 200

    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                # Best effort: the connection may already be broken; the
                # original error below is the one worth reporting.
                pass
        print("\n❌ === ERROR RESUMING HOLD BILL ===")
        print(f"   Error: {str(e)}")
        import traceback; traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        # Runs on every path (success, early return, error) so the cursor and
        # connection are never leaked.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()