from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from db.db import get_db_connection
from config.auth import role_required
from datetime import datetime, timedelta
import traceback

# Flask blueprint that carries the purchase-return routes defined below
# (submit / update / delete purchase return).
purchase_return_bp = Blueprint('purchase_return', __name__)


def safe_float(value, default=0.0):
    """Convert ``value`` to ``float``, falling back to ``default``.

    ``None``, the empty string and the literal string ``'null'`` count as
    "no value" and yield ``default``. Any input that ``float()`` rejects with
    ``ValueError`` or ``TypeError`` also yields ``default``.
    """
    missing_markers = (None, '', 'null')
    try:
        if value in missing_markers:
            return default
        return float(value)
    except (ValueError, TypeError):
        return default


def safe_int(value, default=0):
    """Convert ``value`` to ``int``, falling back to ``default``.

    ``None``, ``''`` and ``'null'`` are treated as absent and return
    ``default``. Inputs that ``int()`` cannot handle also return ``default``;
    this includes non-integral strings such as ``"3.5"``.
    """
    missing_markers = (None, '', 'null')
    try:
        if value in missing_markers:
            return default
        return int(value)
    except (ValueError, TypeError):
        return default


def restore_stock_with_batch_tracking(cursor, store_id, warehouse_id, product_id, variation_id, quantity, order_id):
    """
    ✅ BATCH-AWARE: Put returned stock back into THIS PO's newest GRN batch.

    Used when a 'Received' purchase return is updated or deleted. The whole
    ``quantity`` is credited to the newest (highest batch_id) batch that came
    in through a completed GRN of ``order_id`` for the given warehouse/store:

    1. ``product_batches.remaining_quantity`` of that batch is increased.
    2. The ``warehouse_stock`` row for the same batch is increased atomically,
       or created when no such row exists.

    NOTE(review): the per-batch split of the original deduction is not stored
    anywhere, so this cannot reverse it exactly; all units go to one batch.

    Args:
        cursor: Dictionary cursor running inside the caller's transaction.
        store_id: Store the stock is restored to.
        warehouse_id: Warehouse the stock is restored to.
        product_id: Product being restored.
        variation_id: Variation id, or None for products without variations.
        quantity: Units to restore; ``None`` or values <= 0 are a no-op.
        order_id: Purchase order whose GRN batches are eligible targets.

    Raises:
        ValueError: if no completed GRN batch exists for this PO and product.
            Previously the units were silently dropped and the caller
            committed anyway, losing stock.
    """
    if quantity is None or quantity <= 0:
        return

    print(f"  🔄 Restoring {quantity} units with batch tracking...")

    # Newest batch of THIS PO; only one batch ever receives the restored units.
    if variation_id:
        cursor.execute("""
            SELECT pb.batch_id
            FROM product_batches pb
            INNER JOIN grn g ON pb.grn_id = g.grn_id
            WHERE pb.variation_id = %s
            AND g.purchase_order_id = %s
            AND g.status = 'completed'
            AND g.warehouse_id = %s
            AND g.store_id = %s
            ORDER BY pb.batch_id DESC
            LIMIT 1
        """, (variation_id, order_id, warehouse_id, store_id))
    else:
        cursor.execute("""
            SELECT pb.batch_id
            FROM product_batches pb
            INNER JOIN grn g ON pb.grn_id = g.grn_id
            WHERE pb.product_id = %s
            AND pb.variation_id IS NULL
            AND g.purchase_order_id = %s
            AND g.status = 'completed'
            AND g.warehouse_id = %s
            AND g.store_id = %s
            ORDER BY pb.batch_id DESC
            LIMIT 1
        """, (product_id, order_id, warehouse_id, store_id))

    batch = cursor.fetchone()
    if not batch:
        target = f"variation {variation_id}" if variation_id else f"product {product_id}"
        raise ValueError(
            f"Cannot restore {quantity} units of {target}: no completed GRN batch "
            f"found for Purchase Order #{order_id}"
        )

    batch_id = batch['batch_id']

    # 1. Restore to product_batches
    cursor.execute("""
        UPDATE product_batches
        SET remaining_quantity = remaining_quantity + %s
        WHERE batch_id = %s
    """, (quantity, batch_id))
    print(f"    ✅ Restored {quantity} to batch {batch_id}")

    # 2. ✅ CRITICAL: Restore to warehouse_stock with THIS batch_id.
    # FOR UPDATE plus a relative increment avoid losing concurrent changes.
    cursor.execute("""
        SELECT id, quantity
        FROM warehouse_stock
        WHERE store_id = %s
          AND warehouse_id = %s
          AND product_id = %s
          AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
          AND batch_id = %s
        LIMIT 1
        FOR UPDATE
    """, (store_id, warehouse_id, product_id, variation_id, variation_id, batch_id))
    warehouse_record = cursor.fetchone()

    if warehouse_record:
        cursor.execute("""
            UPDATE warehouse_stock
            SET quantity = quantity + %s, updated_at = NOW()
            WHERE id = %s
        """, (quantity, warehouse_record['id']))
        print(f"    ✅ Warehouse stock (batch {batch_id}): {warehouse_record['quantity']} + {quantity}")
    else:
        cursor.execute("""
            INSERT INTO warehouse_stock
            (store_id, warehouse_id, product_id, variation_id, batch_id, quantity, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, NOW())
        """, (store_id, warehouse_id, product_id, variation_id, batch_id, quantity))
        print(f"    ✅ Created new warehouse_stock record for batch {batch_id} with qty {quantity}")


# ============================================
# CREATE PURCHASE RETURN (✅ BATCH-AWARE STOCK DEDUCTION)
# ============================================
@purchase_return_bp.route('/submit_purchase_return', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def submit_purchase_return():
    """
    ✅ BATCH-AWARE: Create Purchase Return with proper batch tracking

    Flow:
    1. Validate purchase order exists and matches supplier, warehouse and store
    2. Validate stock availability from THIS PO's GRN batches
    3. Insert return header and items
    4. If status='Received':
       - Deduct from product_batches (FIFO from THIS PO's GRN)
       - Deduct the SAME amount from the warehouse_stock record matching the
         batch_id; if the batch has no such record, fall back to another
         warehouse_stock row (quantity > 0) for the product in this warehouse

    Request JSON: order_id, supplier_id, warehouse_id, store_id, return_note,
    return_tax, return_discount, return_total, return_status (default
    'Pending') and return_items (list of item objects with product_id,
    variation_id, quantity, price, unit_price, discount, tax, subtotal,
    discount_type, product_discount, tax_type, product_tax).

    Returns:
        201 with return_id on success; 400 on invalid input or validation
        failure (transaction rolled back); 500 on unexpected errors.
    """
    try:
        data = request.get_json(force=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400

    # A JSON array or scalar would make every data.get() below raise.
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid JSON data', 'message': 'Expected a JSON object'}), 400

    print("=" * 80)
    print("📦 PURCHASE RETURN CREATION (BATCH-AWARE)")
    print("=" * 80)

    # Extract data
    order_id = safe_int(data.get('order_id'))
    supplier_id = safe_int(data.get('supplier_id'))
    warehouse_id = safe_int(data.get('warehouse_id'))
    store_id = safe_int(data.get('store_id'))
    return_note = data.get('return_note', '')
    return_tax = safe_float(data.get('return_tax', 0.00))
    return_discount = safe_float(data.get('return_discount', 0.00))
    return_total = safe_float(data.get('return_total', 0.00))
    return_status = data.get('return_status', 'Pending')
    return_items = data.get('return_items', [])
    stock_affected = return_status == 'Received'

    # Validation
    if not order_id or not supplier_id:
        return jsonify({'error': 'Missing required fields: order_id or supplier_id'}), 400
    if not store_id:
        return jsonify({'error': 'Store ID is required'}), 400
    if not warehouse_id:
        return jsonify({'error': 'Warehouse ID is required'}), 400
    if not isinstance(return_items, list) or not return_items:
        return jsonify({'error': 'At least one return item is required'}), 400

    # Initialised up front so the except/finally handlers never see unbound names.
    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)
        conn.start_transaction()
        print("\n🔄 Transaction started")

        # Step 1: Verify purchase order exists
        print(f"\n📝 Step 1: Validating purchase order {order_id}...")
        cursor.execute("""
            SELECT order_id, status, grand_total, warehouse_id, store_id
            FROM purchase_orders
            WHERE order_id = %s AND supplier_id = %s
        """, (order_id, supplier_id))

        purchase_order = cursor.fetchone()
        if not purchase_order:
            raise ValueError(f'Purchase order {order_id} not found for supplier {supplier_id}')

        if purchase_order['warehouse_id'] != warehouse_id:
            raise ValueError(f"Warehouse mismatch: PO warehouse is {purchase_order['warehouse_id']}, but return specifies {warehouse_id}")

        if purchase_order['store_id'] != store_id:
            raise ValueError(f"Store mismatch: PO store is {purchase_order['store_id']}, but return specifies {store_id}")

        print(f"  ✅ PO validated - Warehouse: {warehouse_id}, Store: {store_id}")

        # Step 2: Insert purchase return header
        print(f"\n📝 Step 2: Creating purchase return header...")
        return_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        cursor.execute("""
            INSERT INTO purchase_returns (
                order_id, supplier_id, warehouse_id, store_id, return_note,
                return_tax, return_discount, return_total,
                return_status, created_on
            )
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            order_id, supplier_id, warehouse_id, store_id, return_note,
            return_tax, return_discount, return_total,
            return_status, return_date
        ))

        return_id = cursor.lastrowid
        print(f"  ✅ Purchase return created with ID: {return_id}")

        # Step 3: Process return items
        print(f"\n📝 Step 3: Processing {len(return_items)} items...")
        items_processed = 0

        for idx, item in enumerate(return_items, 1):
            try:
                product_id = safe_int(item.get('product_id'))
                variation_id = safe_int(item.get('variation_id')) if item.get('variation_id') else None
                quantity = safe_float(item.get('quantity', 0))
                price = safe_float(item.get('price', 0.00))
                unit_price = safe_float(item.get('unit_price', price))
                discount = safe_float(item.get('discount', 0.00))
                tax = safe_float(item.get('tax', 0.00))
                subtotal = safe_float(item.get('subtotal', 0.00))
                discount_type = item.get('discount_type', 'fixed')
                product_discount = safe_float(item.get('product_discount', 0.00))
                tax_type = item.get('tax_type', 'exclusive')
                product_tax = safe_float(item.get('product_tax', 0.00))

                if quantity <= 0:
                    print(f"  ⚠️  Skipping item {idx}: quantity is 0")
                    continue

                print(f"\n  📦 Item {idx}: Product {product_id}, Qty {quantity}")

                # Validate product exists
                cursor.execute("SELECT id, product_name FROM products WHERE id = %s", (product_id,))
                product = cursor.fetchone()
                if not product:
                    raise ValueError(f"Product ID {product_id} does not exist")

                print(f"     Product: {product['product_name']}")

                # Validate variation if provided
                if variation_id:
                    cursor.execute("""
                        SELECT id, variation_name FROM product_variations
                        WHERE id = %s AND product_id = %s
                    """, (variation_id, product_id))
                    variation = cursor.fetchone()
                    if not variation:
                        raise ValueError(f"Variation ID {variation_id} invalid for product {product_id}")
                    print(f"     Variation: {variation['variation_name']}")

                # Validate stock availability if status is 'Received'
                if stock_affected:
                    print(f"     Status: Received - Validating stock from THIS PO's GRN...")

                    # Check batch stock ONLY FROM THIS PURCHASE ORDER'S GRN
                    if variation_id:
                        cursor.execute("""
                            SELECT COALESCE(SUM(pb.remaining_quantity), 0) as batch_stock
                            FROM product_batches pb
                            INNER JOIN grn g ON pb.grn_id = g.grn_id
                            WHERE pb.variation_id = %s
                            AND pb.remaining_quantity > 0
                            AND g.purchase_order_id = %s
                            AND g.status = 'completed'
                            AND g.warehouse_id = %s
                            AND g.store_id = %s
                        """, (variation_id, order_id, warehouse_id, store_id))
                    else:
                        cursor.execute("""
                            SELECT COALESCE(SUM(pb.remaining_quantity), 0) as batch_stock
                            FROM product_batches pb
                            INNER JOIN grn g ON pb.grn_id = g.grn_id
                            WHERE pb.product_id = %s
                            AND pb.variation_id IS NULL
                            AND pb.remaining_quantity > 0
                            AND g.purchase_order_id = %s
                            AND g.status = 'completed'
                            AND g.warehouse_id = %s
                            AND g.store_id = %s
                        """, (product_id, order_id, warehouse_id, store_id))

                    result = cursor.fetchone()
                    batch_total_from_this_po = safe_float(result['batch_stock']) if result else 0.00

                    print(f"     Batch Stock (THIS PO #{order_id}): {batch_total_from_this_po}")

                    if quantity > batch_total_from_this_po:
                        raise ValueError(
                            f"Cannot return {quantity} units of {product['product_name']}. " +
                            f"Only {batch_total_from_this_po} units available from Purchase Order #{order_id}."
                        )

                # Step 3a: Insert return item
                cursor.execute("""
                    INSERT INTO purchase_return_items (
                        return_id, product_id, variation_id, discount_type, product_discount,
                        tax_type, product_tax, quantity, price, unit_price,
                        discount, tax, subtotal, created_on
                    )
                    VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                """, (
                    return_id, product_id, variation_id, discount_type, product_discount,
                    tax_type, product_tax, quantity, price, unit_price,
                    discount, tax, subtotal, return_date
                ))
                print(f"     ✅ Return item created")

                # Step 3b: ✅ BATCH-AWARE STOCK DEDUCTION
                if stock_affected:
                    print(f"     🔄 Deducting stock with batch tracking...")

                    # Batches from THIS PO's GRN, oldest first (FIFO). FOR UPDATE
                    # locks them so concurrent returns cannot consume the same units.
                    if variation_id:
                        cursor.execute("""
                            SELECT pb.batch_id, pb.remaining_quantity
                            FROM product_batches pb
                            INNER JOIN grn g ON pb.grn_id = g.grn_id
                            WHERE pb.variation_id = %s
                            AND pb.remaining_quantity > 0
                            AND g.purchase_order_id = %s
                            AND g.status = 'completed'
                            AND g.warehouse_id = %s
                            AND g.store_id = %s
                            ORDER BY pb.batch_id ASC
                            FOR UPDATE
                        """, (variation_id, order_id, warehouse_id, store_id))
                    else:
                        cursor.execute("""
                            SELECT pb.batch_id, pb.remaining_quantity
                            FROM product_batches pb
                            INNER JOIN grn g ON pb.grn_id = g.grn_id
                            WHERE pb.product_id = %s
                            AND pb.variation_id IS NULL
                            AND pb.remaining_quantity > 0
                            AND g.purchase_order_id = %s
                            AND g.status = 'completed'
                            AND g.warehouse_id = %s
                            AND g.store_id = %s
                            ORDER BY pb.batch_id ASC
                            FOR UPDATE
                        """, (product_id, order_id, warehouse_id, store_id))

                    batches = cursor.fetchall()

                    if not batches:
                        raise ValueError(
                            f"No batches found for Purchase Order #{order_id}. " +
                            "Cannot process return."
                        )

                    remaining_to_deduct = quantity

                    for batch in batches:
                        if remaining_to_deduct <= 0:
                            break

                        batch_id = batch['batch_id']
                        batch_take = min(safe_float(batch['remaining_quantity']), remaining_to_deduct)

                        # Prefer the warehouse_stock row tied to THIS batch_id.
                        cursor.execute("""
                            SELECT id, quantity
                            FROM warehouse_stock
                            WHERE store_id = %s
                              AND warehouse_id = %s
                              AND product_id = %s
                              AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
                              AND batch_id = %s
                            LIMIT 1
                            FOR UPDATE
                        """, (store_id, warehouse_id, product_id, variation_id, variation_id, batch_id))
                        warehouse_record = cursor.fetchone()

                        if not warehouse_record:
                            # ⚠️ No row for this batch - fall back to another row for the product
                            print(f"     ⚠️  No warehouse_stock record for batch {batch_id}, checking fallback...")
                            cursor.execute("""
                                SELECT id, quantity, batch_id
                                FROM warehouse_stock
                                WHERE store_id = %s
                                  AND warehouse_id = %s
                                  AND product_id = %s
                                  AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
                                  AND quantity > 0
                                ORDER BY
                                    CASE WHEN batch_id IS NULL THEN 1 ELSE 0 END,
                                    id ASC
                                LIMIT 1
                                FOR UPDATE
                            """, (store_id, warehouse_id, product_id, variation_id, variation_id))
                            warehouse_record = cursor.fetchone()
                            if not warehouse_record:
                                continue

                        warehouse_qty = safe_float(warehouse_record['quantity'])
                        # Take only what BOTH the batch and the warehouse row can give,
                        # so product_batches and warehouse_stock move by the same amount.
                        deduct_now = min(batch_take, warehouse_qty)
                        if deduct_now <= 0:
                            continue

                        cursor.execute("""
                            UPDATE product_batches
                            SET remaining_quantity = remaining_quantity - %s
                            WHERE batch_id = %s
                        """, (deduct_now, batch_id))

                        cursor.execute("""
                            UPDATE warehouse_stock
                            SET quantity = quantity - %s, updated_at = NOW()
                            WHERE id = %s
                        """, (deduct_now, warehouse_record['id']))

                        print(f"     ✅ Batch {batch_id}: deducted {deduct_now} "
                              f"(warehouse_stock {warehouse_record['id']}: {warehouse_qty} - {deduct_now} = {warehouse_qty - deduct_now})")

                        remaining_to_deduct -= deduct_now

                    if remaining_to_deduct > 0:
                        raise ValueError(
                            f"Stock shortage: Could not deduct {remaining_to_deduct} units of {product['product_name']} " +
                            f"from Purchase Order #{order_id}'s warehouse stock"
                        )

                items_processed += 1

            except Exception as item_error:
                print(f"  ❌ Error processing item {idx}: {item_error}")
                raise

        # Commit transaction
        conn.commit()
        print(f"\n✅ Transaction committed! Items: {items_processed}/{len(return_items)}")

        print("\n" + "=" * 80)
        print("✅ PURCHASE RETURN COMPLETE")
        print("=" * 80)
        print(f"✅ Return ID: {return_id}")
        print(f"✅ Purchase Order: #{order_id}")
        print(f"✅ Status: {return_status}")
        print(f"✅ Items: {items_processed}")
        print(f"✅ Stock Updated: {stock_affected}")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': 'Purchase return submitted successfully',
            'return_id': return_id,
            'return_total': return_total,
            'stock_updated': stock_affected,
            'items_processed': items_processed
        }), 201

    except ValueError as err:
        if conn:
            conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400
    except Exception as e:
        if conn:
            conn.rollback()
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# UPDATE PURCHASE RETURN (✅ BATCH-AWARE)
# ============================================
@purchase_return_bp.route('/update_purchase_return/<int:return_id>', methods=['PUT'])
@jwt_required()
@role_required('admin', 'manager')
def update_purchase_return(return_id):
    """
    ✅ BATCH-AWARE: Update Purchase Return with proper batch tracking

    Flow:
    1. Validate the payload and the purchase order (supplier / warehouse / store),
       using the same rules as submit_purchase_return
    2. If the stored return was 'Received', restore its items' stock to the
       stored PO's batches (restore_stock_with_batch_tracking)
    3. Replace the header and all items
    4. If the new status is 'Received', deduct stock FIFO from THIS PO's GRN
       batches, moving product_batches and warehouse_stock by the same amount

    Request JSON: order_id, supplier (object with 'id') or supplier_id,
    warehouse_id, store_id, return_note, return_tax, return_discount,
    return_total, return_status (default 'Pending') and items (each with
    product_id, variation_id, return_quantity and pricing fields).

    Returns:
        200 on success; 400 on invalid input / validation failure (rolled
        back); 500 on unexpected errors.
    """
    try:
        data = request.get_json(force=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
    except Exception as e:
        return jsonify({'error': 'Invalid JSON', 'message': str(e)}), 400

    # A JSON array or scalar would make every data.get() below raise.
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid JSON', 'message': 'Expected a JSON object'}), 400

    print("=" * 80)
    print(f"📝 UPDATING PURCHASE RETURN #{return_id} (BATCH-AWARE)")
    print("=" * 80)

    supplier = data.get('supplier')
    order_id = safe_int(data.get('order_id'))
    supplier_id = safe_int(supplier.get('id')) if isinstance(supplier, dict) else safe_int(data.get('supplier_id'))
    warehouse_id = safe_int(data.get('warehouse_id'))
    store_id = safe_int(data.get('store_id'))
    return_note = data.get('return_note', '')
    return_tax = safe_float(data.get('return_tax', 0.00))
    return_discount = safe_float(data.get('return_discount', 0.00))
    return_total = safe_float(data.get('return_total', 0.00))
    new_status = data.get('return_status', 'Pending')
    return_items = data.get('items', [])

    # Same required-field rules as submit_purchase_return; previously missing
    # ids were silently written to the header as 0.
    if not order_id or not supplier_id:
        return jsonify({'error': 'Missing required fields: order_id or supplier_id'}), 400
    if not store_id:
        return jsonify({'error': 'Store ID is required'}), 400
    if not warehouse_id:
        return jsonify({'error': 'Warehouse ID is required'}), 400
    if not isinstance(return_items, list) or not return_items:
        return jsonify({'error': 'At least one return item is required'}), 400

    # Initialised up front: previously a failed connection made `finally`
    # raise UnboundLocalError on `cursor`, and early errors broke `rollback`.
    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)
        conn.start_transaction()

        # Get existing return; locked so two concurrent updates cannot both
        # reverse/apply the same stock movement.
        cursor.execute("""
            SELECT return_status, warehouse_id, store_id, order_id
            FROM purchase_returns
            WHERE return_id = %s
            FOR UPDATE
        """, (return_id,))

        old_return = cursor.fetchone()
        if not old_return:
            raise ValueError(f'Return ID {return_id} not found')

        old_status = old_return['return_status']
        old_warehouse_id = old_return['warehouse_id']
        old_store_id = old_return['store_id']
        old_order_id = old_return['order_id']

        print(f"Old status: {old_status} → New: {new_status}")
        print(f"Purchase Order: #{old_order_id}")

        # Validate the (possibly new) purchase order like submit_purchase_return does
        cursor.execute("""
            SELECT order_id, warehouse_id, store_id
            FROM purchase_orders
            WHERE order_id = %s AND supplier_id = %s
        """, (order_id, supplier_id))

        purchase_order = cursor.fetchone()
        if not purchase_order:
            raise ValueError(f'Purchase order {order_id} not found for supplier {supplier_id}')
        if purchase_order['warehouse_id'] != warehouse_id:
            raise ValueError(f"Warehouse mismatch: PO warehouse is {purchase_order['warehouse_id']}, but return specifies {warehouse_id}")
        if purchase_order['store_id'] != store_id:
            raise ValueError(f"Store mismatch: PO store is {purchase_order['store_id']}, but return specifies {store_id}")

        # ✅ BATCH-AWARE: Reverse old stock if was Received
        if old_status == 'Received':
            print("\n🔄 Reversing old stock changes (with batch tracking)...")

            cursor.execute("""
                SELECT product_id, variation_id, quantity
                FROM purchase_return_items
                WHERE return_id = %s
            """, (return_id,))

            for old_item in cursor.fetchall():
                restore_stock_with_batch_tracking(
                    cursor, old_store_id, old_warehouse_id,
                    old_item['product_id'], old_item['variation_id'],
                    safe_float(old_item['quantity']), old_order_id
                )

        # Delete old items
        cursor.execute("DELETE FROM purchase_return_items WHERE return_id = %s", (return_id,))

        # Update header
        cursor.execute("""
            UPDATE purchase_returns
            SET order_id=%s, supplier_id=%s, warehouse_id=%s, store_id=%s,
                return_note=%s, return_tax=%s, return_discount=%s,
                return_total=%s, return_status=%s
            WHERE return_id=%s
        """, (
            order_id, supplier_id, warehouse_id, store_id,
            return_note, return_tax, return_discount,
            return_total, new_status, return_id
        ))

        # Insert new items
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        for item in return_items:
            product_id = safe_int(item.get('product_id'))
            variation_id = safe_int(item.get('variation_id')) if item.get('variation_id') else None
            quantity = safe_float(item.get('return_quantity', 0))

            if quantity <= 0:
                # Same rule as submit_purchase_return. A negative quantity would
                # otherwise ADD stock through the FIFO deduction below.
                print(f"  ⚠️  Skipping product {product_id}: quantity is {quantity}")
                continue

            discount_type = item.get('discount_type', 'fixed')
            product_discount = safe_float(item.get('product_discount', 0.00))
            tax_type = item.get('tax_type', 'exclusive')
            product_tax = safe_float(item.get('product_tax', 0.00))
            unit_price = safe_float(item.get('unit_price', 0.00))
            price = safe_float(item.get('price', 0.00))
            discount = safe_float(item.get('discount', 0.00))
            tax = safe_float(item.get('tax', 0.00))
            subtotal = safe_float(item.get('subtotal', 0.00))

            # Insert item
            cursor.execute("""
                INSERT INTO purchase_return_items (
                    return_id, product_id, variation_id, discount_type, product_discount,
                    tax_type, product_tax, quantity, unit_price, price,
                    discount, tax, subtotal, created_on
                ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """, (
                return_id, product_id, variation_id, discount_type, product_discount,
                tax_type, product_tax, quantity, unit_price, price,
                discount, tax, subtotal, current_time
            ))

            # ✅ BATCH-AWARE: Deduct stock if Received
            if new_status != 'Received':
                continue

            # Batches from THIS PO's GRN, oldest first (FIFO), locked for this transaction
            if variation_id:
                cursor.execute("""
                    SELECT pb.batch_id, pb.remaining_quantity
                    FROM product_batches pb
                    INNER JOIN grn g ON pb.grn_id = g.grn_id
                    WHERE pb.variation_id = %s
                    AND pb.remaining_quantity > 0
                    AND g.purchase_order_id = %s
                    AND g.status = 'completed'
                    AND g.warehouse_id = %s
                    AND g.store_id = %s
                    ORDER BY pb.batch_id ASC
                    FOR UPDATE
                """, (variation_id, order_id, warehouse_id, store_id))
            else:
                cursor.execute("""
                    SELECT pb.batch_id, pb.remaining_quantity
                    FROM product_batches pb
                    INNER JOIN grn g ON pb.grn_id = g.grn_id
                    WHERE pb.product_id = %s
                    AND pb.variation_id IS NULL
                    AND pb.remaining_quantity > 0
                    AND g.purchase_order_id = %s
                    AND g.status = 'completed'
                    AND g.warehouse_id = %s
                    AND g.store_id = %s
                    ORDER BY pb.batch_id ASC
                    FOR UPDATE
                """, (product_id, order_id, warehouse_id, store_id))

            batches = cursor.fetchall()
            remaining_to_deduct = quantity

            for batch in batches:
                if remaining_to_deduct <= 0:
                    break

                batch_id = batch['batch_id']
                batch_take = min(safe_float(batch['remaining_quantity']), remaining_to_deduct)

                # Prefer the warehouse_stock row tied to THIS batch_id
                cursor.execute("""
                    SELECT id, quantity
                    FROM warehouse_stock
                    WHERE store_id = %s
                      AND warehouse_id = %s
                      AND product_id = %s
                      AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
                      AND batch_id = %s
                    LIMIT 1
                    FOR UPDATE
                """, (store_id, warehouse_id, product_id, variation_id, variation_id, batch_id))
                warehouse_record = cursor.fetchone()

                if not warehouse_record:
                    # Fallback to generic records
                    cursor.execute("""
                        SELECT id, quantity
                        FROM warehouse_stock
                        WHERE store_id = %s
                          AND warehouse_id = %s
                          AND product_id = %s
                          AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
                          AND quantity > 0
                        ORDER BY id ASC
                        LIMIT 1
                        FOR UPDATE
                    """, (store_id, warehouse_id, product_id, variation_id, variation_id))
                    warehouse_record = cursor.fetchone()
                    if not warehouse_record:
                        continue

                # Move product_batches and warehouse_stock by the SAME amount; the
                # old code could decrement the batch by more than the warehouse row.
                deduct_now = min(batch_take, safe_float(warehouse_record['quantity']))
                if deduct_now <= 0:
                    continue

                cursor.execute("""
                    UPDATE product_batches
                    SET remaining_quantity = remaining_quantity - %s
                    WHERE batch_id = %s
                """, (deduct_now, batch_id))

                cursor.execute("""
                    UPDATE warehouse_stock
                    SET quantity = quantity - %s, updated_at = NOW()
                    WHERE id = %s
                """, (deduct_now, warehouse_record['id']))

                remaining_to_deduct -= deduct_now

            if remaining_to_deduct > 0:
                raise ValueError(f"Stock shortage for product {product_id} from Purchase Order #{order_id}")

        conn.commit()
        print("\n✅ Update committed successfully")

        return jsonify({
            'success': True,
            'message': 'Purchase return updated successfully',
            'return_id': return_id,
            'stock_updated': new_status == 'Received'
        }), 200

    except ValueError as err:
        if conn:
            conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400
    except Exception as e:
        if conn:
            conn.rollback()
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# DELETE PURCHASE RETURN (✅ BATCH-AWARE)
# ============================================
@purchase_return_bp.route('/delete_purchase_return/<int:return_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin', 'manager')
def delete_purchase_return(return_id):
    """
    ✅ BATCH-AWARE: Delete Purchase Return with proper batch tracking

    Runs in a single transaction. If the return's status is 'Received',
    each returned item's quantity is passed to
    ``restore_stock_with_batch_tracking`` (scoped to the return's store,
    warehouse and purchase order). Then the return's items and header row
    are deleted.

    Args:
        return_id: ``purchase_returns.return_id`` taken from the URL.

    Returns:
        A ``(json, status)`` tuple:
        - 200 with a ``stock_restored`` flag on success;
        - 404 if the return does not exist;
        - 400 for a ``ValueError`` raised while processing;
        - 500 for a connection failure or any other error.
        Every failure after the transaction starts is rolled back.
    """
    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)
        conn.start_transaction()

        # Get return details. FOR UPDATE locks the header row until commit or
        # rollback. Two concurrent deletes of the same return therefore cannot
        # both restore its stock: the second one waits, then finds no row (404).
        cursor.execute("""
            SELECT return_status, warehouse_id, store_id, order_id
            FROM purchase_returns 
            WHERE return_id = %s
            FOR UPDATE
        """, (return_id,))
        
        result = cursor.fetchone()
        if not result:
            # Nothing was modified; end the open transaction explicitly.
            conn.rollback()
            return jsonify({'error': f'Return ID {return_id} not found'}), 404

        return_status = result['return_status']
        warehouse_id = result['warehouse_id']
        store_id = result['store_id']
        order_id = result['order_id']

        # Get items
        cursor.execute("""
            SELECT product_id, variation_id, quantity
            FROM purchase_return_items
            WHERE return_id = %s
        """, (return_id,))
        
        returned_items = cursor.fetchall()

        # ✅ BATCH-AWARE: Reverse stock only if the return was already applied
        # to stock (status 'Received').
        if return_status == "Received":
            print(f"\n🔄 Reversing stock for deleted return (with batch tracking)...")
            
            for item in returned_items:
                product_id = item['product_id']
                variation_id = item['variation_id']
                return_qty = safe_float(item['quantity'])
                
                restore_stock_with_batch_tracking(
                    cursor, store_id, warehouse_id,
                    product_id, variation_id, return_qty, order_id
                )

        # Delete children first, then the header row.
        cursor.execute("DELETE FROM purchase_return_items WHERE return_id = %s", (return_id,))
        cursor.execute("DELETE FROM purchase_returns WHERE return_id = %s", (return_id,))

        conn.commit()
        
        return jsonify({
            'success': True,
            'message': f'Purchase Return {return_id} deleted successfully',
            'stock_restored': return_status == 'Received'
        }), 200

    except ValueError as err:
        # Validation-style failures are client errors (400), not server errors.
        if conn:
            conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400
    except Exception as e:
        if conn:
            conn.rollback()
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# GET PURCHASE RETURNS (WITH FILTERS)
# ============================================
def _purchase_return_date_range(date_filter, start_date, end_date, today):
    """Map a ``dateFilter`` keyword to an inclusive ``(start, end)`` date pair.

    Weeks start on Monday (``date.weekday() == 0``).

    Returns:
        ``None`` when no date restriction applies: a missing or unknown
        filter, or ``customRange`` without both bounds.

    Raises:
        ValueError: a ``customRange`` bound is not in ``YYYY-MM-DD`` format.
    """
    if date_filter == "today":
        return today, today
    if date_filter == "yesterday":
        day = today - timedelta(days=1)
        return day, day
    if date_filter == "thisWeek":
        week_start = today - timedelta(days=today.weekday())
        return week_start, week_start + timedelta(days=6)
    if date_filter == "lastWeek":
        week_start = today - timedelta(days=today.weekday() + 7)
        return week_start, week_start + timedelta(days=6)
    if date_filter == "thisMonth":
        month_start = today.replace(day=1)
        # Last day of this month = first day of next month minus one day.
        if today.month == 12:
            next_month_start = month_start.replace(year=today.year + 1, month=1)
        else:
            next_month_start = month_start.replace(month=today.month + 1)
        return month_start, next_month_start - timedelta(days=1)
    if date_filter == "lastMonth":
        last_month_end = today.replace(day=1) - timedelta(days=1)
        return last_month_end.replace(day=1), last_month_end
    if date_filter == "customRange" and start_date and end_date:
        return (
            datetime.strptime(start_date, "%Y-%m-%d").date(),
            datetime.strptime(end_date, "%Y-%m-%d").date(),
        )
    return None


@purchase_return_bp.route('/get_purchase_returns', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_purchase_returns():
    """Get all purchase returns with filtering and pagination.

    Query parameters (all optional):
        store_id, warehouse_id: exact-match filters on the return.
        dateFilter: today | yesterday | thisWeek | lastWeek | thisMonth |
            lastMonth | customRange. ``customRange`` uses startDate/endDate
            (YYYY-MM-DD) and returns 400 if either is malformed.
        search: substring match on the PO id or the supplier name.
        status: exact match on ``return_status``.
        page, per_page: 1-based pagination. Non-numeric or < 1 values fall
            back to/clamp at sane values instead of failing. Defaults are
            page 1 and 20 per page.

    Returns:
        200 with ``returns`` (newest first) and ``pagination`` metadata;
        400 for an invalid custom date range; 500 on database errors.
    """
    conn = None
    cursor = None

    try:
        store_id = request.args.get('store_id')
        warehouse_id = request.args.get('warehouse_id')
        date_filter = request.args.get('dateFilter')
        start_date = request.args.get('startDate')
        end_date = request.args.get('endDate')
        search = request.args.get('search', '').strip()
        status = request.args.get('status', '').strip()
        # Clamp to >= 1. A bad value otherwise causes a 500, a negative
        # OFFSET/LIMIT (SQL error), or a ZeroDivisionError in the page count.
        page = max(1, safe_int(request.args.get('page'), 1))
        per_page = max(1, safe_int(request.args.get('per_page'), 20))

        offset = (page - 1) * per_page

        # Validate the date range before opening a DB connection.
        try:
            date_range = _purchase_return_date_range(
                date_filter, start_date, end_date, datetime.now().date()
            )
        except ValueError:
            return jsonify({"error": "Invalid custom date range"}), 400

        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        query = """
            SELECT 
                pr.return_id, pr.order_id,
                po.order_id AS purchase_id, 
                s.supplier_name,
                pr.warehouse_id, w.warehouse_name,
                pr.store_id, st.store_name,
                pr.return_status, pr.return_total, 
                pr.return_tax, pr.return_discount,
                pr.created_on, 
                po.grand_total AS purchase_total,
                (po.grand_total - pr.return_total) AS due_amount
            FROM purchase_returns pr
            JOIN purchase_orders po ON pr.order_id = po.order_id
            JOIN suppliers s ON pr.supplier_id = s.id
            LEFT JOIN warehouses w ON pr.warehouse_id = w.id
            LEFT JOIN stores st ON pr.store_id = st.id
        """

        # Conditions and params must stay in the same order.
        conditions = []
        params = []

        if store_id:
            conditions.append("pr.store_id = %s")
            params.append(store_id)

        if warehouse_id:
            conditions.append("pr.warehouse_id = %s")
            params.append(warehouse_id)

        if date_range:
            conditions.append("DATE(pr.created_on) BETWEEN %s AND %s")
            params.extend(date_range)

        if search:
            conditions.append("(CAST(po.order_id AS CHAR) LIKE %s OR s.supplier_name LIKE %s)")
            params.extend([f"%{search}%", f"%{search}%"])

        if status:
            conditions.append("pr.return_status = %s")
            params.append(status)

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        # Total over the filtered set, computed before LIMIT/OFFSET are added.
        count_query = f"SELECT COUNT(*) as total FROM ({query}) as subquery"
        cursor.execute(count_query, params)
        total = cursor.fetchone()['total']

        query += " ORDER BY pr.created_on DESC LIMIT %s OFFSET %s"
        params.extend([per_page, offset])

        cursor.execute(query, params)
        data = cursor.fetchall()

        for row in data:
            if row.get("created_on") and isinstance(row["created_on"], datetime):
                row["created_on"] = row["created_on"].strftime("%Y-%m-%d %H:%M:%S")

        return jsonify({
            'success': True,
            'returns': data,
            'pagination': {
                'page': page,
                'per_page': per_page,
                'total': total,
                'pages': (total + per_page - 1) // per_page
            }
        }), 200

    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# GET SINGLE PURCHASE RETURN (✅ ENHANCED FOR EDITING)
# ============================================
def _po_batch_remaining_stock(cursor, product_id, variation_id, order_id, warehouse_id, store_id):
    """Return the summed ``remaining_quantity`` of one PO's batches for an item.

    Counts only batches from GRNs with status 'completed' that belong to
    purchase order ``order_id`` in the given warehouse and store.

    Matching rules:
    - variation items match on ``variation_id``;
    - simple products match on ``product_id`` with a NULL ``variation_id``.

    Returns 0.0 when no row is returned.
    """
    if variation_id:
        cursor.execute("""
            SELECT COALESCE(SUM(pb.remaining_quantity), 0) as current_stock
            FROM product_batches pb
            INNER JOIN grn g ON pb.grn_id = g.grn_id
            WHERE pb.variation_id = %s
            AND g.purchase_order_id = %s
            AND g.status = 'completed'
            AND g.warehouse_id = %s
            AND g.store_id = %s
        """, (variation_id, order_id, warehouse_id, store_id))
    else:
        cursor.execute("""
            SELECT COALESCE(SUM(pb.remaining_quantity), 0) as current_stock
            FROM product_batches pb
            INNER JOIN grn g ON pb.grn_id = g.grn_id
            WHERE pb.product_id = %s
            AND pb.variation_id IS NULL
            AND g.purchase_order_id = %s
            AND g.status = 'completed'
            AND g.warehouse_id = %s
            AND g.store_id = %s
        """, (product_id, order_id, warehouse_id, store_id))

    stock_result = cursor.fetchone()
    return safe_float(stock_result['current_stock']) if stock_result else 0.00


@purchase_return_bp.route('/get_purchase_return/<int:return_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_purchase_return(return_id):
    """
    Get single purchase return with complete details
    ✅ ENHANCED: Includes warehouse stock + already returned quantity for editing

    For each item the response also reports:
    - ``warehouse_quantity``: remaining stock in this PO's completed-GRN
      batches for the return's warehouse/store;
    - ``available_stock``: that value plus the quantity already on this
      return.

    Numeric columns are converted with ``safe_float``, so a NULL becomes 0.0
    instead of crashing the request.

    Returns:
        200 with the return header and ``items``; 404 if the return does
        not exist; 500 on database errors.
    """
    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        # Fetch return header
        cursor.execute("""
            SELECT 
                pr.return_id, pr.order_id, pr.supplier_id,
                pr.warehouse_id, w.warehouse_name,
                pr.store_id, 
                st.store_name, st.email as store_email, 
                st.contact as store_contact, st.address as store_address,
                st.is_active as store_is_active,
                s.supplier_name, s.supplier_code, s.supplier_contact,
                s.supplier_email, s.supplier_optional_contact, s.supplier_address,
                pr.return_note, pr.return_tax, pr.return_discount,
                pr.return_total, pr.return_status, pr.created_on
            FROM purchase_returns pr
            LEFT JOIN suppliers s ON pr.supplier_id = s.id
            LEFT JOIN warehouses w ON pr.warehouse_id = w.id
            LEFT JOIN stores st ON pr.store_id = st.id
            WHERE pr.return_id = %s
        """, (return_id,))

        return_data = cursor.fetchone()
        if not return_data:
            return jsonify({'error': 'Purchase Return not found'}), 404

        # ✅ Get store and warehouse IDs for stock calculation
        store_id = return_data['store_id']
        warehouse_id = return_data['warehouse_id']
        order_id = return_data['order_id']

        # Fetch items, joined to the originating order line for the ordered
        # quantity and purchase unit.
        cursor.execute("""
            SELECT 
                pri.return_item_id, pri.product_id, pri.variation_id,
                pri.quantity AS return_quantity,
                pri.price, pri.unit_price, pri.discount, pri.tax, pri.subtotal,
                pri.discount_type, pri.product_discount,
                pri.tax_type, pri.product_tax,
                p.product_name, p.sku, p.purchase_unit_id,
                pu.unit_short as purchase_unit,
                v.variation_name, v.variation_sku,
                oi.quantity AS order_quantity,
                oi.purchase_unit AS order_purchase_unit
            FROM purchase_return_items pri
            LEFT JOIN products p ON pri.product_id = p.id
            LEFT JOIN units pu ON p.purchase_unit_id = pu.id
            LEFT JOIN product_variations v ON pri.variation_id = v.id
            LEFT JOIN order_items oi 
                   ON oi.order_id = %s
                  AND oi.product_id = pri.product_id 
                  AND (oi.variation_id = pri.variation_id 
                       OR (oi.variation_id IS NULL AND pri.variation_id IS NULL))
            WHERE pri.return_id = %s
        """, (order_id, return_id))

        items = cursor.fetchall()

        formatted_items = []
        for item in items:
            product_id = item['product_id']
            variation_id = item['variation_id']
            return_quantity = safe_float(item['return_quantity'])
            # Prefer the unit recorded on the order line; fall back to the
            # product's purchase unit.
            purchase_unit = item.get('order_purchase_unit') or item.get('purchase_unit') or ''

            # ✅ CRITICAL: current warehouse stock from THIS PO's GRN batches
            current_warehouse_stock = _po_batch_remaining_stock(
                cursor, product_id, variation_id, order_id, warehouse_id, store_id
            )

            # ✅ Available stock = current warehouse stock + already returned quantity
            available_stock = current_warehouse_stock + return_quantity

            print(f"📊 Edit Load - Product {product_id}:")
            print(f"   Current Warehouse Stock (from PO batches): {current_warehouse_stock}")
            print(f"   Already Returned Qty: {return_quantity}")
            print(f"   Available for Edit: {available_stock}")
            
            formatted_items.append({
                'return_item_id': item['return_item_id'],
                'product_id': product_id,
                'variation_id': variation_id,
                'return_quantity': return_quantity,
                'price': safe_float(item['price']),
                'unit_price': safe_float(item['unit_price']),
                'discount': safe_float(item['discount']),
                'tax': safe_float(item['tax']),
                'subtotal': safe_float(item['subtotal']),
                'discount_type': item['discount_type'],
                'product_discount': safe_float(item['product_discount']),
                'tax_type': item['tax_type'],
                'product_tax': safe_float(item['product_tax']),
                'product_name': item['product_name'],
                'sku': item['sku'],
                'variation_name': item['variation_name'],
                'variation_sku': item['variation_sku'],
                # NULL when the return item has no matching order line.
                'order_quantity': safe_float(item['order_quantity']),
                'purchase_unit': purchase_unit,
                # ✅ NEW: Stock information for editing
                'warehouse_quantity': current_warehouse_stock,
                'available_stock': available_stock
            })

        created_on = return_data['created_on']
        response = {
            'return_id': return_data['return_id'],
            'order_id': order_id,
            'supplier_id': return_data['supplier_id'],
            'supplier_name': return_data['supplier_name'],
            'supplier_code': return_data['supplier_code'],
            'supplier_contact': return_data['supplier_contact'],
            'supplier_email': return_data['supplier_email'],
            'supplier_optional_contact': return_data['supplier_optional_contact'],
            'supplier_address': return_data['supplier_address'],
            'warehouse_id': warehouse_id,
            'warehouse_name': return_data['warehouse_name'],
            'store_id': store_id,
            'store_name': return_data['store_name'],
            'store_email': return_data['store_email'],
            'store_contact': return_data['store_contact'],
            'store_address': return_data['store_address'],
            'store_is_active': return_data['store_is_active'],
            'return_note': return_data['return_note'],
            'return_tax': safe_float(return_data['return_tax']),
            'return_discount': safe_float(return_data['return_discount']),
            'return_total': safe_float(return_data['return_total']),
            'return_status': return_data['return_status'],
            'created_on': created_on.strftime('%Y-%m-%d %H:%M:%S') if created_on else None,
            'items': formatted_items
        }

        return jsonify(response), 200

    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# CHECK IF PURCHASE RETURN EXISTS
# ============================================
@purchase_return_bp.route('/check_purchase_return/<int:order_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager', 'cashier')
def check_purchase_return(order_id):
    """Check if a purchase return exists for an order.

    Args:
        order_id: purchase order id taken from the URL.

    Returns:
        200 with ``exists`` plus ``purchase_return_id`` (the first matching
        return) when one exists; 500 on connection or query failure.
    """
    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        # Buffered like the other endpoints, so closing the cursor after
        # fetchone() never trips over an unread result set.
        cursor = conn.cursor(dictionary=True, buffered=True)

        cursor.execute("""
            SELECT return_id 
            FROM purchase_returns 
            WHERE order_id = %s 
            LIMIT 1
        """, (order_id,))
        
        purchase_return = cursor.fetchone()

        if purchase_return:
            return jsonify({
                'success': True,
                'exists': True, 
                'purchase_return_id': purchase_return['return_id']
            }), 200

        return jsonify({
            'success': True,
            'exists': False
        }), 200

    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
    finally:
        # Release resources on every path; previously an exception from
        # execute/fetchone leaked both the cursor and the connection.
        if cursor:
            cursor.close()
        if conn:
            conn.close()