from datetime import datetime, timedelta
import json
import math
import traceback

import mysql.connector
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required

from config.auth import role_required
from db.db import get_db_connection

purchase_bp = Blueprint('purchase', __name__)


def safe_float(value, default=0.0):
    """Safely convert *value* to a finite float.

    Returns *default* instead of raising when *value* is ``None``, ``''`` or
    the string ``'null'``, when it cannot be parsed (``ValueError`` /
    ``TypeError``), when it is too large to represent (``OverflowError``,
    e.g. a huge JSON integer), or when the result is NaN or +/-infinity.
    Python's JSON parser accepts ``NaN`` / ``Infinity`` literals, so request
    payloads can carry them; letting them through would poison totals.

    Args:
        value: raw value, typically taken from a JSON payload.
        default: value returned when no finite float can be produced.

    Returns:
        The converted float, or *default*.
    """
    try:
        if value in (None, '', 'null'):
            return default
        result = float(value)
    except (ValueError, TypeError, OverflowError):
        return default
    # Reject NaN/inf: they compare oddly (NaN <= 0 is False) and corrupt sums.
    return result if math.isfinite(result) else default


def safe_int(value, default=0):
    """Safely convert *value* to an int.

    Returns *default* instead of raising when *value* is ``None``, ``''`` or
    the string ``'null'``, or when ``int(value)`` fails. That includes
    ``OverflowError`` for ``float('inf')`` (a JSON ``Infinity`` literal),
    which the helper previously let escape. Floats are truncated toward zero
    as ``int()`` does; non-integer strings such as ``'1.5'`` give *default*.

    Args:
        value: raw value, typically taken from a JSON payload.
        default: value returned when conversion is not possible.

    Returns:
        The converted int, or *default*.
    """
    try:
        return int(value) if value not in (None, '', 'null') else default
    except (ValueError, TypeError, OverflowError):
        return default

# Shared query for in-stock batch groups (one row per distinct price/cost pair).
# `{variation_filter}` is only ever replaced with one of the two fixed SQL
# fragments below -- never with user input.
_PO_SEARCH_BATCH_GROUPS_SQL = """
    SELECT 
        pb.price,
        pb.cost,
        MAX(pb.expiration_date) AS latest_expiration_date,
        SUM(ws.quantity) AS total_stock,
        GROUP_CONCAT(DISTINCT pb.batch_id ORDER BY pb.batch_id SEPARATOR ',') AS batch_ids,
        GROUP_CONCAT(DISTINCT CONCAT(
            COALESCE(g.grn_code, 'N/A'), 
            ' (', 
            COALESCE(DATE_FORMAT(g.grn_date, '%Y-%m-%d'), 'N/A'), 
            ')'
        ) SEPARATOR ', ') AS grn_info
    FROM product_batches pb
    INNER JOIN warehouse_stock ws ON pb.batch_id = ws.batch_id
    LEFT JOIN grn g ON pb.grn_id = g.grn_id
    WHERE pb.product_id = %s 
    AND {variation_filter}
    AND pb.remaining_quantity > 0
    AND ws.warehouse_id = %s
    AND ws.store_id = %s
    AND ws.quantity > 0
    GROUP BY pb.price, pb.cost
    ORDER BY pb.price ASC, pb.cost ASC
"""
# Params: (product_id, variation_id, warehouse_id, store_id)
_PO_SEARCH_VARIATION_BATCHES_SQL = _PO_SEARCH_BATCH_GROUPS_SQL.format(
    variation_filter="pb.variation_id = %s"
)
# Params: (product_id, warehouse_id, store_id)
_PO_SEARCH_SINGLE_BATCHES_SQL = _PO_SEARCH_BATCH_GROUPS_SQL.format(
    variation_filter="pb.variation_id IS NULL"
)


def _po_float_or(value, default=0.0):
    """Return ``float(value)`` if *value* is truthy, otherwise *default*.

    NULL columns, 0 and ``Decimal('0')`` all fall back to *default*.
    """
    return float(value) if value else default


def _po_unit_fields(product):
    """Return the unit fields (short and full names) attached to every search entry."""
    return {
        "base_unit": product.get('base_unit_short'),
        "base_unit_name": product.get('base_unit_name'),
        "sale_unit": product.get('sale_unit_short'),
        "sale_unit_name": product.get('sale_unit_name'),
        "purchase_unit": product.get('purchase_unit_short'),
        "purchase_unit_name": product.get('purchase_unit_name'),  # full name
    }


def _po_resolve_tax(product, variation=None):
    """Return ``(tax_type, product_tax)`` for a single product or one of its variations.

    A variation's own tax type / rate wins when set (truthy); otherwise the
    parent product's values are used. A NULL ``product_tax`` yields 0.0
    rather than raising ``TypeError`` from ``float(None)``.
    """
    if variation is None:
        return product['tax_type'], _po_float_or(product['product_tax'])

    tax_type = variation.get('variation_tax_type') or product.get('tax_type')
    variation_tax = variation.get('variation_tax')
    if variation_tax:
        product_tax = float(variation_tax)
    else:
        product_tax = _po_float_or(product.get('product_tax'))
    return tax_type, product_tax


def _po_batch_figures(batch_group, default_price=0.0, default_cost=0.0):
    """Return ``(price, cost, stock, exp_date)`` for one batch-group row.

    A falsy price/cost falls back to the given defaults; a missing expiration
    date is rendered as ``'-'``.
    """
    latest_exp = batch_group['latest_expiration_date']
    return (
        _po_float_or(batch_group['price'], default_price),
        _po_float_or(batch_group['cost'], default_cost),
        _po_float_or(batch_group['total_stock']),
        latest_exp.strftime('%Y-%m-%d') if latest_exp else '-',
    )


def _po_search_entry(product, variation, *, batch_ids, grn_info, quantity,
                     price, cost, exp_date, tax_type, product_tax):
    """Build one result entry of the purchase-order product search.

    ``variation`` is a ``product_variations`` row for variable products or
    ``None`` for single products; it decides the label, SKU and product_type.
    """
    if variation is None:
        label = f"{product['product_name']} ({product['sku']})"
    else:
        label = f"{product['product_name']} - {variation['variation_name']}"

    return {
        "product_id": product['id'],
        "variation_id": None if variation is None else variation['id'],
        "batch_ids": batch_ids,
        "grn_info": grn_info,
        "product_name": product['product_name'],
        "variation_name": None if variation is None else variation['variation_name'],
        "variation_type": None if variation is None else variation['variation_type'],
        "sku": product['sku'] if variation is None else variation['variation_sku'],
        "display_name": (
            f"{label} - (Cost {cost:.0f} - Price {price:.0f}) "
            f"Stock {quantity:.0f} Exp: {exp_date}"
        ),
        "product_type": "single" if variation is None else "variation",
        "product_quantity": quantity,
        "product_price": price,
        "product_cost": cost,
        "expiration_date": exp_date,
        "tax_type": tax_type,
        "product_tax": product_tax,
        **_po_unit_fields(product),
    }


def _po_variable_product_entries(cursor, product, variation_name, warehouse_id, store_id):
    """Return search entries for the variations of *product* matching *variation_name*.

    Each matching variation yields one entry per in-stock price/cost batch
    group, or a single zero-stock entry priced from the variation's own
    cost/price when it has no stock in the given warehouse/store.
    """
    pattern = f"%{variation_name}%"
    cursor.execute("""
        SELECT 
            pv.id,
            pv.variation_name,
            pv.variation_type,
            pv.variation_sku,
            pv.variation_cost,
            pv.variation_price,
            pv.variation_tax_type,
            pv.variation_tax
        FROM product_variations pv
        WHERE pv.product_id = %s
        AND (pv.variation_name LIKE %s 
             OR pv.variation_type LIKE %s 
             OR pv.variation_sku LIKE %s)
    """, (product['id'], pattern, pattern, pattern))
    variations = cursor.fetchall()

    print(f"  📦 Product: {product['product_name']} ({len(variations)} variations)")

    entries = []
    for var in variations:
        cursor.execute(
            _PO_SEARCH_VARIATION_BATCHES_SQL,
            (product['id'], var['id'], warehouse_id, store_id),
        )
        batches = cursor.fetchall()

        default_price = _po_float_or(var['variation_price'])
        default_cost = _po_float_or(var['variation_cost'])
        tax_type, product_tax = _po_resolve_tax(product, var)

        if not batches:
            entries.append(_po_search_entry(
                product, var,
                batch_ids=None, grn_info=None, quantity=0.0,
                price=default_price, cost=default_cost, exp_date='-',
                tax_type=tax_type, product_tax=product_tax,
            ))
            print(f"    ⚠️ {var['variation_name']} - No stock")
            continue

        for batch_group in batches:
            price, cost, stock, exp_date = _po_batch_figures(batch_group, default_price, default_cost)
            entries.append(_po_search_entry(
                product, var,
                batch_ids=batch_group['batch_ids'],
                grn_info=batch_group.get('grn_info', 'N/A'),
                quantity=stock, price=price, cost=cost, exp_date=exp_date,
                tax_type=tax_type, product_tax=product_tax,
            ))
            print(f"    ✅ {var['variation_name']} - Cost {cost:.0f}, Price {price:.0f}, Stock {stock:.0f}")
            print(f"       Unit: {product.get('purchase_unit_name')} ({product.get('purchase_unit_short')})")
    return entries


def _po_single_product_entries(cursor, product, warehouse_id, store_id):
    """Return search entries for a single (non-variable) product.

    One entry per in-stock price/cost batch group, or a single zero-stock
    entry (cost/price 0) when the product has no stock in the warehouse/store.
    """
    cursor.execute(_PO_SEARCH_SINGLE_BATCHES_SQL, (product['id'], warehouse_id, store_id))
    batches = cursor.fetchall()
    tax_type, product_tax = _po_resolve_tax(product)

    if not batches:
        entry = _po_search_entry(
            product, None,
            batch_ids=None, grn_info=None, quantity=0.0,
            price=0.0, cost=0.0, exp_date='-',
            tax_type=tax_type, product_tax=product_tax,
        )
        print(f"  ⚠️ {product['product_name']} - No stock")
        return [entry]

    entries = []
    for batch_group in batches:
        price, cost, stock, exp_date = _po_batch_figures(batch_group)
        entries.append(_po_search_entry(
            product, None,
            batch_ids=batch_group['batch_ids'],
            grn_info=batch_group.get('grn_info', 'N/A'),
            quantity=stock, price=price, cost=cost, exp_date=exp_date,
            tax_type=tax_type, product_tax=product_tax,
        ))
        print(f"  ✅ {product['product_name']} - Cost {cost:.0f}, Price {price:.0f}, Stock {stock:.0f}")
        print(f"     Unit: {product.get('purchase_unit_name')} ({product.get('purchase_unit_short')})")
    return entries


@purchase_bp.route('/search_purchase_order_product', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def search_purchase_order_product():
    """Search products for a Purchase Order, with current stock information.

    Query parameters:
        query: search text, at least 2 characters. ``"<product> - <variation>"``
            searches products by the first part and variations by the second;
            otherwise the whole text is used for both.
        warehouse_id: required; stock is counted only in this warehouse.
        store_id: required; stock is counted only in this store.

    Returns:
        200 with a JSON list of entries, one per in-stock price/cost batch
        group, or one zero-stock entry per product/variation without stock.
        Every entry carries both short and full unit names.
        400 on missing/invalid parameters; 500 on database or server errors.
    """
    try:
        search_query = request.args.get('query', '').strip()
        warehouse_id = request.args.get('warehouse_id')
        store_id = request.args.get('store_id')

        print("\n" + "=" * 80)
        print("🔍 SEARCH PURCHASE ORDER PRODUCT (FIXED - FULL UNIT NAMES)")
        print("=" * 80)
        print(f"Search Query: '{search_query}'")
        print(f"Warehouse ID: {warehouse_id}")
        print(f"Store ID: {store_id}")

        if len(search_query) < 2:
            return jsonify({'error': 'Search query must be at least 2 characters'}), 400

        if not warehouse_id:
            return jsonify({'error': 'Warehouse ID is required'}), 400

        if not store_id:
            return jsonify({'error': 'Store ID is required'}), 400

        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = None
        try:
            # Created inside the try so the connection is closed even if this fails.
            cursor = conn.cursor(dictionary=True)

            # "Product - Variation" splits into two search terms; otherwise the
            # whole query is used for both products and variations.
            if " - " in search_query:
                main_name, variation_name = (
                    part.strip() for part in search_query.split(" - ", 1)
                )
            else:
                main_name = variation_name = search_query

            cursor.execute("""
                SELECT 
                    p.id,
                    p.product_name,
                    p.sku,
                    p.product_type,
                    p.tax_type,
                    p.product_tax,
                    p.base_unit_id,
                    p.sale_unit_id,
                    p.purchase_unit_id,
                    bu.base_unit AS base_unit_name,
                    bu.base_unit AS base_unit_short,
                    su.unit_name AS sale_unit_name,
                    su.unit_short AS sale_unit_short,
                    pu.unit_name AS purchase_unit_name,
                    pu.unit_short AS purchase_unit_short
                FROM products p
                LEFT JOIN base_units bu ON p.base_unit_id = bu.id
                LEFT JOIN units su ON p.sale_unit_id = su.id
                LEFT JOIN units pu ON p.purchase_unit_id = pu.id
                WHERE p.product_name LIKE %s OR p.sku LIKE %s
            """, (f"%{main_name}%", f"%{main_name}%"))
            products = cursor.fetchall()

            print(f"✅ Found {len(products)} matching products")

            result = []
            for product in products:
                if product['product_type'] == 'variable':
                    # NOTE(review): without a " - " separator, variations are filtered
                    # by the whole query, so a search by product name alone only returns
                    # variations whose name/type/SKU also contain it -- confirm intended.
                    result.extend(_po_variable_product_entries(
                        cursor, product, variation_name, warehouse_id, store_id,
                    ))
                elif (
                    main_name.lower() in product['product_name'].lower()
                    or variation_name.lower() in product['product_name'].lower()
                    or main_name.lower() in (product['sku'] or '').lower()
                ):
                    result.extend(_po_single_product_entries(
                        cursor, product, warehouse_id, store_id,
                    ))

            print(f"\n✅ Returning {len(result)} product entries")
            print("=" * 80)

            return jsonify(result), 200

        except mysql.connector.Error as err:
            print(f"\n❌ Database Error: {err}")
            traceback.print_exc()
            return jsonify({'error': f'Database error: {str(err)}'}), 500

        except Exception as e:
            print(f"\n❌ Unexpected Error: {e}")
            traceback.print_exc()
            return jsonify({'error': f'Server error: {str(e)}'}), 500

        finally:
            if cursor is not None:
                cursor.close()
            conn.close()

    except Exception as e:
        print(f"\n❌ Error in search_purchase_order_product: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

@purchase_bp.route('/submit_order', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def submit_order():
    """Submit a Purchase Order.

    Creates one ``purchase_orders`` row (with ``grn_status='not_received'``)
    and its ``order_items`` rows in a single transaction. It does NOT create
    ``product_batches`` or ``warehouse_stock``; those are created only when a
    GRN is approved.

    Expected JSON object body:
        supplier, warehouse_id: required IDs (validated against the DB).
        store_id: defaults to 1.
        note, order_tax, discount, grand_total: optional.
        status: defaults to 'Ordered'.
        payment_status: defaults to 'Unpaid'. 'Paid' means the full
            grand_total is paid. 'Unpaid' means nothing is paid. Any other
            value is treated as Partial and uses ``paid_amount``, which must
            be between 0 and grand_total.
        payment_type: required when payment_status is Paid or Partial.
        products: non-empty list of item objects. Items with quantity <= 0
            are skipped, but at least one item must remain.

    Returns:
        201 with the created order summary. 400 for malformed input, an
        unknown supplier/warehouse, or a foreign-key violation. 500 for
        other database/server errors.
    """

    try:
        data = request.get_json(force=True)
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid JSON data', 'message': 'Expected a JSON object'}), 400

    print("=" * 80)
    print("📦 PURCHASE ORDER SUBMISSION - CORRECT FLOW")
    print("=" * 80)
    print(json.dumps(data, indent=2, default=str))
    print("=" * 80)

    # Extract data. `or <default>` also covers explicit JSON nulls, which
    # previously crashed the .strip()/.lower() calls with AttributeError.
    supplier_id = data.get('supplier')
    warehouse_id = data.get('warehouse_id')
    store_id = data.get('store_id', 1)
    note = str(data.get('note') or '').strip()
    order_tax = safe_float(data.get('order_tax', 0))
    discount = safe_float(data.get('discount', 0))
    status = data.get('status') or 'Ordered'  # Default to 'Ordered', not 'Received'
    payment_status = str(data.get('payment_status') or 'Unpaid')
    payment_type = data.get('payment_type')
    grand_total = safe_float(data.get('grand_total', 0))
    products = data.get('products', [])
    payment_status_key = payment_status.lower()

    # Validation
    if not supplier_id:
        return jsonify({'error': 'Supplier is required'}), 400

    if not warehouse_id:
        return jsonify({'error': 'Warehouse is required'}), 400

    if not isinstance(products, list) or not products:
        return jsonify({'error': 'At least one product is required'}), 400

    if payment_status_key in ('paid', 'partial') and not payment_type:
        return jsonify({'error': 'Payment type is required when payment status is Paid or Partial'}), 400

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Created inside the try so the connection is closed even if this fails.
        cursor = conn.cursor(dictionary=True)
        conn.start_transaction()
        print("\n🔄 Transaction started")

        # Calculate payment amounts
        print(f"\n📝 Step 1: Calculating payment amounts...")

        if payment_status_key == "paid":
            paid_amount = grand_total
            due_amount = 0.0
        elif payment_status_key == "unpaid":
            paid_amount = 0.0
            due_amount = grand_total
        else:  # Partial (any status other than paid/unpaid)
            paid_amount = safe_float(data.get('paid_amount', 0))
            if paid_amount < 0 or paid_amount > grand_total:
                raise ValueError(
                    f"Paid amount {paid_amount} must be between 0 and the grand total {grand_total}"
                )
            due_amount = grand_total - paid_amount

        print(f"  Grand Total: {grand_total}")
        print(f"  Paid: {paid_amount}")
        print(f"  Due: {due_amount}")
        print(f"  Payment Type: {payment_type}")

        # Validate supplier
        print(f"\n📝 Step 2: Validating supplier...")
        cursor.execute("SELECT id FROM suppliers WHERE id = %s", (supplier_id,))
        if not cursor.fetchone():
            raise ValueError(f"Supplier with ID {supplier_id} not found")
        print(f"✅ Supplier ID {supplier_id} validated")

        # Validate warehouse
        print(f"\n📝 Step 3: Validating warehouse...")
        cursor.execute("SELECT id FROM warehouses WHERE id = %s", (warehouse_id,))
        if not cursor.fetchone():
            raise ValueError(f"Warehouse with ID {warehouse_id} not found")
        print(f"✅ Warehouse ID {warehouse_id} validated")

        # Create purchase order
        print(f"\n📝 Step 4: Creating purchase order...")
        order_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        # GRN status starts as 'not_received'; stock arrives only via GRN approval.
        cursor.execute("""
            INSERT INTO purchase_orders (
                supplier_id, warehouse_id, store_id, note, 
                order_tax, discount, status, grn_status,
                payment_status, payment_type,
                grand_total, paid_amount, due_amount, created_on
            )
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            supplier_id, warehouse_id, store_id, note,
            order_tax, discount, status, 'not_received',
            payment_status, payment_type,
            grand_total, paid_amount, due_amount, order_date
        ))

        order_id = cursor.lastrowid
        print(f"✅ Purchase order created with ID: {order_id}")
        print(f"✅ GRN Status: not_received (waiting for goods receipt)")

        # Process products - ONLY create order_items
        print(f"\n📝 Step 5: Processing {len(products)} products...")
        items_processed = 0

        for idx, product in enumerate(products, 1):
            try:
                if not isinstance(product, dict):
                    raise ValueError(f"Product {idx} must be an object")

                product_id = safe_int(product.get('product_id'))
                variation_id = safe_int(product.get('variation_id')) if product.get('variation_id') else None
                quantity = safe_float(product.get('quantity', 0))

                # Pricing
                product_cost = safe_float(product.get('unit_price', 0))      # actual product cost
                display_price = safe_float(product.get('price', 0))          # net unit cost
                selling_price = safe_float(product.get('batch_price', 0))    # selling/retail price

                tax = safe_float(product.get('tax', 0))
                discount_amt = safe_float(product.get('discount', 0))
                subtotal = safe_float(product.get('subtotal', 0))
                purchase_unit = product.get('purchase_unit', '')
                discount_type = product.get('discount_type', 'fixed')
                product_discount = safe_float(product.get('product_discount', 0))
                tax_type = product.get('tax_type', 'exclusive')
                product_tax = safe_float(product.get('product_tax', 0))
                expiration_date = product.get('expiration_date') or None

                if quantity <= 0:
                    print(f"  ⚠️  Skipping product {product_id}: quantity is 0")
                    continue

                # safe_int() maps a missing/invalid id to 0, which is never a real product.
                if product_id <= 0:
                    raise ValueError(f"Product {idx} has an invalid product_id")

                print(f"\n  📦 Product {idx}:")
                print(f"    Product ID: {product_id}")
                print(f"    Variation ID: {variation_id}")
                print(f"    Quantity: {quantity}")
                print(f"    Product Cost: {product_cost}")
                print(f"    Display Price: {display_price}")
                print(f"    Selling Price: {selling_price}")

                cursor.execute("""
                    INSERT INTO order_items (
                        order_id, product_id, variation_id, discount_type, 
                        product_discount, tax_type, product_tax, quantity,
                        unit_price, net_unit_cost, selling_price, purchase_unit, 
                        discount, tax, subtotal, expiration_date, created_on
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    order_id, product_id, variation_id, discount_type,
                    product_discount, tax_type, product_tax, quantity,
                    product_cost,      # unit_price = actual product cost
                    display_price,     # net_unit_cost = display price
                    selling_price,     # selling_price = intended retail price
                    purchase_unit,
                    discount_amt, tax, subtotal, expiration_date, order_date
                ))

                print(f"    ✅ Order item saved with selling price")

                # Batches and warehouse stock are created on GRN approval, not here.
                items_processed += 1

            except Exception as product_error:
                print(f"  ❌ Error processing product {idx}: {product_error}")
                traceback.print_exc()
                raise

        # Never commit an order with no line items (e.g. every quantity was 0).
        if items_processed == 0:
            raise ValueError('At least one product with quantity greater than 0 is required')

        # Commit transaction
        conn.commit()
        print(f"\n✅ Transaction committed successfully!")
        print(f"✅ Total items processed: {items_processed}")

        print("\n" + "=" * 80)
        print("✅ PURCHASE ORDER FLOW - SUMMARY")
        print("=" * 80)
        print(f"✅ Created Tables:")
        print(f"   - purchase_orders (ID: {order_id})")
        print(f"   - order_items ({items_processed} items)")
        print(f"\n❌ NOT Created (will be created on GRN approval):")
        print(f"   - product_batches")
        print(f"   - warehouse_stock")
        print(f"\n📊 Stock Status: 0 units (waiting for GRN approval)")
        print(f"🔄 Next Step: Create GRN from view_all_purchase.html")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': f'Purchase order {order_id} created successfully. Stock will be updated after GRN approval.',
            'order_id': order_id,
            'data': {
                'order_id': order_id,
                'status': status,
                'grn_status': 'not_received',
                'payment_status': payment_status,
                'payment_type': payment_type,
                'grand_total': grand_total,
                'paid_amount': paid_amount,
                'due_amount': due_amount,
                'items_processed': items_processed,
                'stock_updated': False,  # stock is NOT updated until GRN approval
                'next_step': 'Create and approve GRN to update stock'
            }
        }), 201

    except mysql.connector.IntegrityError as err:
        conn.rollback()
        print(f"\n❌ Database Integrity Error: {err}")
        traceback.print_exc()

        error_msg = str(err)
        if 'foreign key constraint' in error_msg.lower():
            return jsonify({'error': 'Invalid reference - check supplier, product, warehouse IDs'}), 400

        return jsonify({'error': f'Database integrity error: {error_msg}'}), 500

    except ValueError as err:
        conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        conn.close()
        print("\n🔒 Database connection closed\n")


def _po_date_range_for_filter(date_filter, start_date, end_date, today):
    """Map a ``dateFilter`` query value to an inclusive ``(start, end)`` date pair.

    Args:
        date_filter: one of ``today``, ``yesterday``, ``thisWeek``, ``lastWeek``,
            ``thisMonth``, ``lastMonth`` or ``customRange``; any other value
            (including ``None``) means "no date restriction".
        start_date, end_date: ``YYYY-MM-DD`` strings, only used for ``customRange``.
        today: reference date for the relative filters.

    Returns:
        ``(start, end)`` dates, or ``None`` when no date filter applies
        (including ``customRange`` without both bounds).

    Raises:
        ValueError: if a ``customRange`` bound is not in ``YYYY-MM-DD`` format.
    """
    if date_filter == "today":
        return today, today
    if date_filter == "yesterday":
        yesterday = today - timedelta(days=1)
        return yesterday, yesterday
    if date_filter == "thisWeek":
        # Weeks start on Monday (date.weekday() == 0).
        start_of_week = today - timedelta(days=today.weekday())
        return start_of_week, start_of_week + timedelta(days=6)
    if date_filter == "lastWeek":
        start_of_last_week = today - timedelta(days=today.weekday() + 7)
        return start_of_last_week, start_of_last_week + timedelta(days=6)
    if date_filter == "thisMonth":
        start_of_month = today.replace(day=1)
        # Last day of this month = (1st of next month) - 1 day; December rolls the year.
        if today.month == 12:
            first_of_next = today.replace(year=today.year + 1, month=1, day=1)
        else:
            first_of_next = today.replace(month=today.month + 1, day=1)
        return start_of_month, first_of_next - timedelta(days=1)
    if date_filter == "lastMonth":
        end_of_last_month = today.replace(day=1) - timedelta(days=1)
        return end_of_last_month.replace(day=1), end_of_last_month
    if date_filter == "customRange" and start_date and end_date:
        start = datetime.strptime(start_date, "%Y-%m-%d").date()
        end = datetime.strptime(end_date, "%Y-%m-%d").date()
        return start, end
    return None


@purchase_bp.route('/get_all_submit_order', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_all_submit_order():
    """
    List purchase orders (newest first) together with their line items.

    Query params (all optional):
        store_id           -- restrict to one store (must be an integer)
        dateFilter         -- today | yesterday | thisWeek | lastWeek |
                              thisMonth | lastMonth | customRange, applied to
                              DATE(po.created_on)
        startDate, endDate -- ``YYYY-MM-DD`` bounds used by ``customRange``

    Each order carries ``grn_status`` (goods-receipt progress); each item
    carries ``selling_price``. ``paid_amount``/``due_amount`` are normalized
    from ``payment_status``: "Paid" -> fully paid, "Unpaid" (or NULL) -> fully
    due, anything else -> due = max(grand_total - paid_amount, 0).

    Responses:
        200 -- JSON list of orders
        400 -- invalid store_id or custom date range
        500 -- DB connection failure or query error
    """
    conn = get_db_connection()
    # Same guard as the sibling endpoints: without it conn.cursor() raises an
    # unhandled AttributeError when the connection could not be opened.
    if conn is None:
        return jsonify({"error": "Database connection failed"}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # Get filter params
        date_filter = request.args.get('dateFilter')
        start_date = request.args.get('startDate')
        end_date = request.args.get('endDate')
        store_id = request.args.get('store_id')

        print(f"📋 Filters - Date: {date_filter}, Store: {store_id}")

        # Base query - one row per order item (LEFT JOIN keeps item-less orders)
        query = """
            SELECT 
                po.order_id, po.supplier_id, s.supplier_name,  
                po.note, po.order_tax, po.discount, po.status, po.grn_status,
                po.grand_total, po.paid_amount, po.due_amount,
                po.payment_status, po.payment_type,
                po.warehouse_id, po.store_id,
                w.warehouse_name, st.store_name,
                po.created_on AS order_date,
                oi.product_id, oi.variation_id, oi.quantity,
                oi.unit_price, oi.net_unit_cost, oi.selling_price, oi.purchase_unit,
                oi.discount AS item_discount, oi.tax AS item_tax,
                oi.subtotal, oi.discount_type, oi.product_discount,
                oi.tax_type, oi.product_tax, oi.expiration_date,
                p.product_name, p.sku,
                v.variation_name, v.variation_type, v.variation_sku
            FROM purchase_orders po
            LEFT JOIN order_items oi ON po.order_id = oi.order_id
            LEFT JOIN products p ON oi.product_id = p.id
            LEFT JOIN suppliers s ON po.supplier_id = s.id
            LEFT JOIN warehouses w ON po.warehouse_id = w.id
            LEFT JOIN stores st ON po.store_id = st.id
            LEFT JOIN product_variations v ON oi.variation_id = v.id
        """

        # Build WHERE conditions; params must stay in the same order as conditions.
        conditions = []
        params = []

        # Store filter
        if store_id:
            try:
                store_id_int = int(store_id)
            except ValueError:
                return jsonify({"error": "Invalid store_id format"}), 400
            conditions.append("po.store_id = %s")
            params.append(store_id_int)
            print(f"✅ Filtering by store_id: {store_id_int}")

        # Date filter (inclusive range on the order's creation date)
        try:
            date_range = _po_date_range_for_filter(
                date_filter, start_date, end_date, datetime.now().date()
            )
        except ValueError:
            return jsonify({"error": "Invalid custom date range format"}), 400
        if date_range:
            conditions.append("DATE(po.created_on) BETWEEN %s AND %s")
            params.extend(date_range)

        # Apply WHERE clause
        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        # item_id as a tiebreaker gives items a stable order within each order
        # (same ordering get_purchase uses).
        query += " ORDER BY po.order_id DESC, oi.item_id"

        print(f"📊 Executing query with params: {params}")
        cursor.execute(query, params)
        rows = cursor.fetchall()

        print(f"✅ Found {len(rows)} order items")

        # Fold the flat rows into one entry per order (dict preserves DESC order).
        orders = {}
        for row in rows:
            order_id = row["order_id"]

            if order_id not in orders:
                grand_total = float(row["grand_total"] or 0)
                paid_amount = float(row["paid_amount"] or 0)
                due_amount = float(row["due_amount"] or 0)
                payment_status = row["payment_status"] or "Unpaid"

                # Recalculate amounts from the payment status rather than
                # trusting the stored paid/due columns.
                if payment_status == "Paid":
                    due_amount = 0
                    paid_amount = grand_total
                elif payment_status == "Unpaid":
                    due_amount = grand_total
                    paid_amount = 0
                else:  # Partial (or any other status)
                    due_amount = max(grand_total - paid_amount, 0)

                orders[order_id] = {
                    "order_id": order_id,
                    "supplier_id": row["supplier_id"],
                    "supplier_name": row["supplier_name"],
                    "warehouse_id": row["warehouse_id"],
                    "warehouse_name": row["warehouse_name"],
                    "store_id": row["store_id"],
                    "store_name": row["store_name"],
                    "note": row["note"],
                    "order_tax": float(row["order_tax"] or 0),
                    "discount": float(row["discount"] or 0),
                    "status": row["status"],
                    "grn_status": row["grn_status"],
                    "grand_total": grand_total,
                    "paid_amount": paid_amount,
                    "due_amount": due_amount,
                    "payment_status": payment_status,
                    "payment_type": row["payment_type"],
                    "order_date": row["order_date"].strftime("%Y-%m-%d %H:%M:%S") if row["order_date"] else None,
                    "items": []
                }

            # Orders without items produce a single row with NULL item columns.
            if row["product_id"]:
                item = {
                    "product_id": row["product_id"],
                    "product_name": row["product_name"],
                    "sku": row["sku"],
                    "variation_id": row["variation_id"],
                    "variation_name": row["variation_name"],
                    "variation_type": row["variation_type"],
                    "variation_sku": row["variation_sku"],
                    "quantity": float(row["quantity"] or 0),
                    "unit_price": float(row["unit_price"] or 0),
                    "net_unit_cost": float(row["net_unit_cost"] or 0),
                    "selling_price": float(row["selling_price"] or 0),
                    "purchase_unit": row["purchase_unit"],
                    "discount": float(row["item_discount"] or 0),
                    "discount_type": row["discount_type"],
                    "product_discount": float(row["product_discount"] or 0),
                    "tax": float(row["item_tax"] or 0),
                    "tax_type": row["tax_type"],
                    "product_tax": float(row["product_tax"] or 0),
                    "subtotal": float(row["subtotal"] or 0),
                    "expiration_date": row["expiration_date"].strftime("%Y-%m-%d") if row["expiration_date"] else None
                }
                orders[order_id]["items"].append(item)

        print(f"✅ Returning {len(orders)} unique orders")
        return jsonify(list(orders.values())), 200

    except Exception as e:
        print("❌ Error in get_all_submit_order:", str(e))
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        cursor.close()
        conn.close()


@purchase_bp.route('/get_purchase/<int:purchase_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_purchase(purchase_id):
    """
    Return a single purchase order (header, supplier, location, line items).

    Line items come straight from ``order_items`` with no batch join, so every
    item appears exactly once. Each item also reports ``product_quantity``:
    the summed ``warehouse_stock`` quantity over all batches of the same
    product/variation held in the order's warehouse (0 when none).

    Responses:
        200 -- the order document
        404 -- no purchase_orders row with this order_id
        500 -- no DB connection, or any error while querying/building
    """
    connection = get_db_connection()
    if connection is None:
        return jsonify({"error": "Database connection failed"}), 500

    cur = connection.cursor(dictionary=True)

    def as_float(value):
        # NULL numeric columns are reported as 0.0
        return float(value or 0)

    def fmt_date(value, pattern):
        # NULL dates stay None; real dates/datetimes are rendered with `pattern`
        return value.strftime(pattern) if value else None

    def line_item(row):
        # One response entry per order_items row (key order is part of the API output)
        return {
            "item_id": row["item_id"],
            "product_id": row["product_id"],
            "product_name": row["product_name"],
            "sku": row["sku"],
            "variation_id": row["variation_id"],
            "variation_name": row["variation_name"],
            "variation_type": row["variation_type"],
            "variation_sku": row["variation_sku"],
            "quantity": as_float(row["quantity"]),
            "unit_price": as_float(row["unit_price"]),
            "net_unit_cost": as_float(row["net_unit_cost"]),
            "selling_price": as_float(row["selling_price"]),
            "purchase_unit": row["purchase_unit"],
            "discount": as_float(row["item_discount"]),
            "discount_type": row["discount_type"],
            "product_discount": as_float(row["product_discount"]),
            "tax": as_float(row["item_tax"]),
            "tax_type": row["tax_type"],
            "product_tax": as_float(row["product_tax"]),
            "subtotal": as_float(row["subtotal"]),
            "expiration_date": fmt_date(row["expiration_date"], "%Y-%m-%d"),
            # aggregated warehouse stock, not per batch
            "product_quantity": as_float(row["product_quantity"]),
        }

    try:
        # --- order header (supplier / warehouse / store resolved via LEFT JOINs) ---
        header_sql = """
            SELECT 
                po.order_id, po.supplier_id, po.warehouse_id, po.store_id,
                po.note, po.order_tax, po.discount, po.status, po.grn_status,
                po.grand_total, po.paid_amount, po.due_amount,
                po.payment_status, po.payment_type,
                po.created_on AS order_date,
                s.supplier_name, s.supplier_code, s.supplier_contact, 
                s.supplier_email, s.supplier_optional_contact, s.supplier_address,
                w.warehouse_name, st.store_name
            FROM purchase_orders po
            LEFT JOIN suppliers s ON po.supplier_id = s.id
            LEFT JOIN warehouses w ON po.warehouse_id = w.id
            LEFT JOIN stores st ON po.store_id = st.id
            WHERE po.order_id = %s
        """
        cur.execute(header_sql, (purchase_id,))
        hdr = cur.fetchone()

        if not hdr:
            return jsonify({"error": "Order not found"}), 404

        # --- line items plus aggregated stock in the order's warehouse ---
        items_sql = """
            SELECT 
                oi.item_id, oi.product_id, oi.variation_id,
                oi.quantity, oi.unit_price, oi.net_unit_cost, oi.selling_price, oi.purchase_unit,
                oi.discount AS item_discount, oi.discount_type, oi.product_discount,
                oi.tax AS item_tax, oi.tax_type, oi.product_tax,
                oi.subtotal, oi.expiration_date,
                p.product_name, p.sku,
                pv.variation_name, pv.variation_type, pv.variation_sku,
                -- ✅ Get total warehouse stock (aggregated across all batches)
                COALESCE(
                    (SELECT SUM(ws.quantity) 
                     FROM warehouse_stock ws
                     INNER JOIN product_batches pb2 ON ws.batch_id = pb2.batch_id
                     WHERE pb2.product_id = oi.product_id 
                     AND (pb2.variation_id = oi.variation_id OR (pb2.variation_id IS NULL AND oi.variation_id IS NULL))
                     AND ws.warehouse_id = %s
                    ), 0
                ) AS product_quantity
            FROM order_items oi
            LEFT JOIN products p ON oi.product_id = p.id
            LEFT JOIN product_variations pv ON oi.variation_id = pv.id
            WHERE oi.order_id = %s
            ORDER BY oi.item_id
        """
        cur.execute(items_sql, (hdr["warehouse_id"], purchase_id))
        line_rows = cur.fetchall()

        supplier_fields = (
            "supplier_name", "supplier_code", "supplier_contact",
            "supplier_email", "supplier_optional_contact", "supplier_address",
        )
        location_fields = ("warehouse_id", "warehouse_name", "store_id", "store_name")

        order = {
            "order_id": hdr["order_id"],
            "supplier": {"id": hdr["supplier_id"], **{f: hdr[f] for f in supplier_fields}},
            "location": {f: hdr[f] for f in location_fields},
            "note": hdr["note"],
            "order_tax": as_float(hdr["order_tax"]),
            "discount": as_float(hdr["discount"]),
            "status": hdr["status"],
            "grn_status": hdr["grn_status"],
            "grand_total": as_float(hdr["grand_total"]),
            "paid_amount": as_float(hdr["paid_amount"]),
            "due_amount": as_float(hdr["due_amount"]),
            "payment_status": hdr["payment_status"],
            "payment_type": hdr["payment_type"],
            "order_date": fmt_date(hdr["order_date"], "%Y-%m-%d %H:%M:%S"),
            "items": [],
        }
        order["items"].extend(map(line_item, line_rows))

        print(f"✅ Successfully fetched purchase order {purchase_id}")
        print(f"   Items: {len(order['items'])} (no duplicates)")

        return jsonify(order), 200

    except Exception as err:
        print(f"❌ Error fetching purchase {purchase_id}:", str(err))
        traceback.print_exc()
        return jsonify({"error": str(err)}), 500

    finally:
        cur.close()
        connection.close()


@purchase_bp.route('/delete_order/<int:order_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin')
def delete_order(order_id):
    """
    Delete a purchase order that has not been received yet.

    Only the order's ``order_items`` rows and its ``purchase_orders`` row are
    deleted, committed together. Batches and stock are created by the GRN
    flow, so an order whose ``grn_status`` is anything other than
    ``'not_received'`` is refused; its GRN must be deleted first.

    Responses:
        200 -- order and its items deleted
        400 -- order has a GRN (current ``grn_status`` echoed back)
        404 -- no such order
        500 -- no DB connection, or a database/unexpected error (rolled back)
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        # Report a failed connection clearly instead of the AttributeError that
        # conn.cursor() would raise on None.
        if not conn:
            return jsonify({"error": "Database connection failed"}), 500
        cursor = conn.cursor(dictionary=True)

        # Step 1: Check if order exists
        cursor.execute("""
            SELECT order_id, grn_status 
            FROM purchase_orders 
            WHERE order_id = %s
        """, (order_id,))
        order = cursor.fetchone()

        if not order:
            return jsonify({"error": "Order not found"}), 404

        grn_status = order['grn_status']

        # Step 2: Refuse if any goods receipt exists for this order
        if grn_status != 'not_received':
            return jsonify({
                "error": "Cannot delete order with GRN. Please delete GRN first.",
                "grn_status": grn_status
            }), 400

        # Step 3: Delete order items (children first)
        cursor.execute("DELETE FROM order_items WHERE order_id = %s", (order_id,))
        print(f"✅ Deleted order items for PO {order_id}")

        # Step 4: Delete purchase order
        cursor.execute("DELETE FROM purchase_orders WHERE order_id = %s", (order_id,))
        print(f"✅ Deleted purchase order {order_id}")

        conn.commit()
        print(f"✅ Purchase order {order_id} deleted successfully")

        return jsonify({
            "message": "Purchase order deleted successfully",
            "order_id": order_id
        }), 200

    except Exception as e:
        if conn:
            conn.rollback()
        print(f"❌ Error deleting purchase order {order_id}: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@purchase_bp.route('/update_order/<int:order_id>', methods=['PUT'])
@jwt_required()
@role_required('admin', 'manager')
def update_order(order_id):
    """
    Update a purchase order in place while preserving existing ``item_id`` values.

    Incoming products are matched to existing ``order_items`` rows by
    ``(product_id, variation_id)``:
      * matched rows are UPDATEd, so they keep their ``item_id``;
      * unmatched products are INSERTed as new rows;
      * existing rows no incoming product claimed are DELETEd.
    Everything runs in one transaction. The update is refused when the order
    already has a GRN (``grn_status`` other than ``'not_received'``). Stock is
    not touched here.

    Request JSON (object):
        supplier, warehouse_id (required); store_id (default 1); note;
        order_tax; discount; status (default 'Ordered'); payment_status
        (Paid / Partial / Unpaid, default Unpaid); payment_type (required for
        Paid/Partial); grand_total; paid_amount (Partial only); products
        (non-empty list of objects).

    Products with quantity <= 0 are skipped, so an existing row for them is
    deleted. At least one product must have quantity > 0.

    Responses:
        200 -- updated
        400 -- invalid input, order not found, order has a GRN, or a
               foreign-key violation (unknown supplier/product/warehouse)
        500 -- DB connection failure or other database/unexpected error
    """
    try:
        data = request.get_json(force=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400

    # A JSON array/scalar body would otherwise crash on data.get() below with an
    # unhandled AttributeError (this code runs outside any try block).
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid JSON data', 'message': 'Request body must be a JSON object'}), 400

    print("=" * 80)
    print("📝 PURCHASE ORDER UPDATE - PRESERVING item_ids")
    print("=" * 80)

    supplier_id    = data.get('supplier')
    warehouse_id   = data.get('warehouse_id')
    store_id       = data.get('store_id', 1)
    # `or` instead of a .get() default: explicit JSON nulls must not reach
    # .strip()/.lower(), which would raise outside the try block.
    note           = str(data.get('note') or '').strip()
    order_tax      = safe_float(data.get('order_tax', 0))
    discount       = safe_float(data.get('discount', 0))
    status         = data.get('status', 'Ordered')
    payment_status = str(data.get('payment_status') or 'Unpaid')
    payment_type   = data.get('payment_type')
    grand_total    = safe_float(data.get('grand_total', 0))
    products       = data.get('products', [])
    payment_key    = payment_status.lower()

    # ── Validation ─────────────────────────────────────────────────────────────
    if not supplier_id:
        return jsonify({'error': 'Supplier is required'}), 400
    if not warehouse_id:
        return jsonify({'error': 'Warehouse is required'}), 400
    if not products:
        return jsonify({'error': 'At least one product is required'}), 400
    if not isinstance(products, list):
        return jsonify({'error': 'Products must be a list'}), 400
    if payment_key in ('paid', 'partial') and not payment_type:
        return jsonify({'error': 'Payment type is required when status is Paid or Partial'}), 400

    # ── Payment amounts ─────────────────────────────────────────────────────────
    if payment_key == 'paid':
        paid_amount = grand_total
        due_amount  = 0.0
    elif payment_key == 'partial':
        paid_amount = safe_float(data.get('paid_amount', 0))
        if paid_amount > grand_total:
            return jsonify({'error': 'Partial payment cannot exceed grand total'}), 400
        if paid_amount <= 0:
            return jsonify({'error': 'Partial payment must be greater than 0'}), 400
        due_amount = grand_total - paid_amount
    else:  # 'unpaid' and any unrecognized status
        paid_amount = 0.0
        due_amount  = grand_total

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        conn.start_transaction()
        print("\n🔄 Transaction started")

        # ── Step 1: Verify order exists and has no GRN ─────────────────────────
        print(f"\n📝 Step 1: Verifying order {order_id}...")
        cursor.execute("""
            SELECT order_id, status, grn_status
            FROM purchase_orders
            WHERE order_id = %s
        """, (order_id,))
        old_order = cursor.fetchone()

        if not old_order:
            raise ValueError(f"Order ID {order_id} not found")

        if old_order['grn_status'] != 'not_received':
            raise ValueError(
                f"Cannot update order with GRN status '{old_order['grn_status']}'. "
                "Delete GRN first or create a new purchase order."
            )
        print(f"  ✅ Order verified - GRN Status: {old_order['grn_status']}")

        # ── Step 2: Load existing item_ids from DB ─────────────────────────────
        print(f"\n📝 Step 2: Loading existing order items...")
        cursor.execute("""
            SELECT item_id, product_id, variation_id
            FROM order_items
            WHERE order_id = %s
        """, (order_id,))
        existing_rows = cursor.fetchall()

        # Map (product_id, variation_id) -> [item_id, ...]. A list so that
        # duplicate lines of the same product are each matched at most once.
        existing_map = {}
        for row in existing_rows:
            key = (row['product_id'], row['variation_id'])
            existing_map.setdefault(key, []).append(row['item_id'])

        print(f"  ✅ Found {len(existing_rows)} existing items")
        for key, ids in existing_map.items():
            print(f"     product_id={key[0]}, variation_id={key[1]} → item_ids={ids}")

        # ── Step 3: Update purchase_orders header ──────────────────────────────
        print(f"\n📝 Step 3: Updating purchase order header...")
        order_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        cursor.execute("""
            UPDATE purchase_orders
            SET supplier_id=%s, warehouse_id=%s, store_id=%s, note=%s,
                order_tax=%s, discount=%s, status=%s,
                payment_status=%s, payment_type=%s,
                grand_total=%s, paid_amount=%s, due_amount=%s,
                updated_on=%s
            WHERE order_id=%s
        """, (
            supplier_id, warehouse_id, store_id, note,
            order_tax, discount, status,
            payment_status, payment_type,
            grand_total, paid_amount, due_amount,
            order_date, order_id
        ))
        print(f"  ✅ Header updated")

        # ── Step 4: Smart UPDATE / INSERT per incoming product ─────────────────
        print(f"\n📝 Step 4: Processing {len(products)} incoming products...")

        used_item_ids   = set()   # item_ids we matched/updated or inserted
        items_processed = 0

        for idx, product in enumerate(products, 1):
            if not isinstance(product, dict):
                raise ValueError(f"Product #{idx} must be an object")

            product_id   = safe_int(product.get('product_id'))
            # safe_int maps None/''/'null'/garbage to 0. Normalise 0 to None so
            # "no variation" matches the NULL variation_id rows loaded in Step 2
            # (otherwise the row is deleted and re-inserted with a new item_id).
            variation_id = safe_int(product.get('variation_id')) or None
            quantity     = safe_float(product.get('quantity', 0))

            if quantity <= 0:
                print(f"    ⚠️  Skipping product {product_id}: quantity is 0")
                continue

            if product_id <= 0:
                raise ValueError(f"Product #{idx} has an invalid product_id")

            # Pricing
            display_price  = safe_float(product.get('price', 0))        # net_unit_cost
            actual_cost    = safe_float(product.get('unit_price', 0))   # unit_price
            selling_price  = safe_float(product.get('batch_price', 0))  # selling_price

            tax            = safe_float(product.get('tax', 0))
            discount_amt   = safe_float(product.get('discount', 0))
            subtotal       = safe_float(product.get('subtotal', 0))
            discount_type  = product.get('discount_type', 'fixed')
            product_disc   = safe_float(product.get('product_discount', 0))
            tax_type       = product.get('tax_type', 'exclusive')
            product_tax    = safe_float(product.get('product_tax', 0))
            purchase_unit  = product.get('purchase_unit', '')
            expiration_date = product.get('expiration_date') or None

            # Try to match an existing item by (product_id, variation_id)
            key = (product_id, variation_id)
            existing_ids = existing_map.get(key, [])

            if existing_ids:
                # Reuse the existing item_id — UPDATE in place
                item_id = existing_ids.pop(0)   # consume one match
                if not existing_ids:
                    del existing_map[key]

                used_item_ids.add(item_id)

                cursor.execute("""
                    UPDATE order_items
                    SET discount_type=%s, product_discount=%s,
                        tax_type=%s, product_tax=%s,
                        quantity=%s,
                        unit_price=%s, net_unit_cost=%s, selling_price=%s,
                        purchase_unit=%s,
                        discount=%s, tax=%s, subtotal=%s,
                        expiration_date=%s
                    WHERE item_id=%s AND order_id=%s
                """, (
                    discount_type, product_disc,
                    tax_type, product_tax,
                    quantity,
                    actual_cost, display_price, selling_price,
                    purchase_unit,
                    discount_amt, tax, subtotal,
                    expiration_date,
                    item_id, order_id
                ))

                print(f"\n    ✅ UPDATED item_id={item_id} "
                      f"(product_id={product_id}, variation_id={variation_id}) qty={quantity}")

            else:
                # Truly new product — INSERT and get a new item_id
                cursor.execute("""
                    INSERT INTO order_items (
                        order_id, product_id, variation_id,
                        discount_type, product_discount,
                        tax_type, product_tax,
                        quantity, unit_price, net_unit_cost, selling_price,
                        purchase_unit, discount, tax, subtotal,
                        expiration_date, created_on
                    ) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                """, (
                    order_id, product_id, variation_id,
                    discount_type, product_disc,
                    tax_type, product_tax,
                    quantity, actual_cost, display_price, selling_price,
                    purchase_unit, discount_amt, tax, subtotal,
                    expiration_date, order_date
                ))
                new_item_id = cursor.lastrowid
                used_item_ids.add(new_item_id)

                print(f"\n    🆕 INSERTED new item_id={new_item_id} "
                      f"(product_id={product_id}, variation_id={variation_id}) qty={quantity}")

            items_processed += 1

        # Every product was skipped: continuing would delete all existing items
        # and leave an empty order, bypassing the "at least one product" rule.
        if items_processed == 0:
            raise ValueError('At least one product with quantity greater than 0 is required')

        # ── Step 5: Delete items that were removed from the order ──────────────
        # Any item_id still in existing_map was NOT sent back → user deleted it
        orphaned_ids = [item_id for ids in existing_map.values() for item_id in ids]

        if orphaned_ids:
            fmt = ','.join(['%s'] * len(orphaned_ids))
            cursor.execute(f"""
                DELETE FROM order_items
                WHERE item_id IN ({fmt}) AND order_id = %s
            """, (*orphaned_ids, order_id))
            print(f"\n    🗑️  Deleted {len(orphaned_ids)} removed item(s): {orphaned_ids}")
        else:
            print(f"\n    ℹ️  No items removed")

        # ── Commit ─────────────────────────────────────────────────────────────
        conn.commit()
        print(f"\n✅ Transaction committed | items processed={items_processed} | deleted={len(orphaned_ids)}")

        return jsonify({
            'success': True,
            'message': f'Purchase order {order_id} updated. Existing item_ids preserved.',
            'order_id': order_id,
            'data': {
                'order_id':        order_id,
                'status':          status,
                'grn_status':      'not_received',
                'payment_status':  payment_status,
                'payment_type':    payment_type,
                'grand_total':     grand_total,
                'paid_amount':     paid_amount,
                'due_amount':      due_amount,
                'items_processed': items_processed,
                'items_deleted':   len(orphaned_ids),
                'stock_updated':   False,
            }
        }), 200

    except mysql.connector.IntegrityError as err:
        conn.rollback()
        print(f"\n❌ DB Integrity Error: {err}")
        traceback.print_exc()
        error_msg = str(err)
        # Same mapping as the create endpoint: a bad reference is a client error.
        if 'foreign key constraint' in error_msg.lower():
            return jsonify({'error': 'Invalid reference - check supplier, product, warehouse IDs'}), 400
        return jsonify({'error': f'Database integrity error: {error_msg}'}), 500

    except ValueError as err:
        conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"\n❌ DB Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 DB connection closed\n")

@purchase_bp.route('/get_order_items', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_order_items():
    """
    Return the (product_id, variation_id) pairs stored for one purchase order.

    Query params:
        order_id (required): the purchase order whose items are listed.

    Responses:
        200 -- ``{"items": [{"product_id": ..., "variation_id": ...}, ...]}``
        400 -- ``order_id`` missing
        500 -- DB connection failure or query error
    """
    order_id = request.args.get('order_id')

    if not order_id:
        return jsonify({"error": "Missing order_id"}), 400

    # Pre-bind so the finally block never hits an unbound name when the
    # connection or cursor could not be obtained.
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({"error": "Database connection failed"}), 500
        cursor = conn.cursor(dictionary=True)

        cursor.execute("""
            SELECT oi.product_id, oi.variation_id
            FROM order_items oi
            WHERE oi.order_id = %s
        """, (order_id,))
        items = cursor.fetchall()

        return jsonify({"items": items}), 200

    except Exception as e:
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@purchase_bp.route('/get_warehouse_stock', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_warehouse_stock():
    """
    Get warehouse stock for a product/variation in a specific store and warehouse.

    Sums warehouse_stock.quantity for the matching rows (0 when none match).
    NOTE(review): earlier comments state stock rows only appear after GRN
    approval; that is enforced elsewhere, not in this endpoint.

    Query params:
    - product_id (required)
    - variation_id (optional; '', 'null' and '0' are treated as "no variation",
      which matches rows with variation_id IS NULL)
    - store_id (required)
    - warehouse_id (required)

    Returns:
    {
        "success": True,
        "stock": 150.0,
        "product_id": "123",
        "variation_id": "45",
        "store_id": "2",
        "warehouse_id": "1"
    }
    400 when a required param is missing, 500 on DB/connection errors.
    """
    # Bound before try: the early-return validation paths below run the
    # finally block, which must not reference unbound names.
    conn = None
    cursor = None

    try:
        product_id = request.args.get('product_id')
        variation_id = request.args.get('variation_id')
        store_id = request.args.get('store_id')
        warehouse_id = request.args.get('warehouse_id')

        if not product_id:
            return jsonify({"error": "product_id is required"}), 400

        if not store_id:
            return jsonify({"error": "store_id is required"}), 400

        if not warehouse_id:
            return jsonify({"error": "warehouse_id is required"}), 400

        conn = get_db_connection()
        if not conn:
            return jsonify({"error": "Database connection failed"}), 500

        cursor = conn.cursor(dictionary=True)

        # Query the warehouse_stock table; a "real" variation id filters on
        # it, otherwise only variation-less rows are summed.
        if variation_id and variation_id != 'null' and variation_id != '0':
            cursor.execute("""
                SELECT COALESCE(SUM(quantity), 0) as stock
                FROM warehouse_stock
                WHERE product_id = %s 
                AND variation_id = %s
                AND store_id = %s
                AND warehouse_id = %s
            """, (product_id, variation_id, store_id, warehouse_id))
        else:
            cursor.execute("""
                SELECT COALESCE(SUM(quantity), 0) as stock
                FROM warehouse_stock
                WHERE product_id = %s 
                AND variation_id IS NULL
                AND store_id = %s
                AND warehouse_id = %s
            """, (product_id, store_id, warehouse_id))

        result = cursor.fetchone()
        stock = float(result['stock']) if result else 0.0

        print(f"📦 Warehouse Stock Query:")
        print(f"   Product ID: {product_id}")
        print(f"   Variation ID: {variation_id or 'None'}")
        print(f"   Store ID: {store_id}")
        print(f"   Warehouse ID: {warehouse_id}")
        print(f"   ✅ Stock: {stock} (0 if no GRN approved)")

        return jsonify({
            "success": True,
            "stock": stock,
            "product_id": product_id,
            "variation_id": variation_id,
            "store_id": store_id,
            "warehouse_id": warehouse_id
        }), 200

    except Exception as e:
        print(f"❌ Error getting warehouse stock: {str(e)}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@purchase_bp.route('/get_order_id_by_supplier_and_product', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_order_id_by_supplier_and_product():
    """
    Get purchase order IDs by supplier name and product name (optional variation).

    Only orders with grn_status = 'fully_received' are returned. Name matching
    is case-insensitive exact equality (LOWER(...) = LOWER(...)). Without a
    variation_name, only order lines with variation_id IS NULL match.

    Query Parameters:
    - supplier_name (required): Name of the supplier
    - product_name (required): Name of the product
    - variation_name (optional): Name of the product variation

    Returns:
    {
        "order_ids": [123, 456],   # newest order_id first
        "count": 2
    }
    400 when a required param is missing, 500 on DB/connection errors.
    """
    # Bound before try so cleanup in finally works on every path.
    conn = None
    cursor = None

    try:
        supplier_name = request.args.get('supplier_name', '').strip()
        product_name = request.args.get('product_name', '').strip()
        variation_name = request.args.get('variation_name', '').strip()

        if not supplier_name:
            return jsonify({'error': 'supplier_name is required'}), 400

        if not product_name:
            return jsonify({'error': 'product_name is required'}), 400

        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True)

        print(f"🔍 Searching for orders:")
        print(f"   Supplier: {supplier_name}")
        print(f"   Product: {product_name}")
        print(f"   Variation: {variation_name or 'None'}")

        if variation_name:
            # Search with variation
            query = """
                SELECT DISTINCT po.order_id
                FROM purchase_orders po
                INNER JOIN suppliers s ON po.supplier_id = s.id
                INNER JOIN order_items oi ON po.order_id = oi.order_id
                INNER JOIN products p ON oi.product_id = p.id
                INNER JOIN product_variations pv ON oi.variation_id = pv.id
                WHERE LOWER(s.supplier_name) = LOWER(%s)
                  AND LOWER(p.product_name) = LOWER(%s)
                  AND LOWER(pv.variation_name) = LOWER(%s)
                  AND po.grn_status = 'fully_received'
                ORDER BY po.order_id DESC
            """
            cursor.execute(query, (supplier_name, product_name, variation_name))
        else:
            # Search without variation (single product)
            query = """
                SELECT DISTINCT po.order_id
                FROM purchase_orders po
                INNER JOIN suppliers s ON po.supplier_id = s.id
                INNER JOIN order_items oi ON po.order_id = oi.order_id
                INNER JOIN products p ON oi.product_id = p.id
                WHERE LOWER(s.supplier_name) = LOWER(%s)
                  AND LOWER(p.product_name) = LOWER(%s)
                  AND oi.variation_id IS NULL
                  AND po.grn_status = 'fully_received'
                ORDER BY po.order_id DESC
            """
            cursor.execute(query, (supplier_name, product_name))

        order_ids = [row['order_id'] for row in cursor.fetchall()]

        print(f"✅ Found {len(order_ids)} order(s) with GRN fully received: {order_ids}")

        return jsonify({
            'order_ids': order_ids,
            'count': len(order_ids)
        }), 200

    except Exception as e:
        print(f"❌ Error in get_order_id_by_supplier_and_product: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        # Previously only closed on success; an exception leaked both handles.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
    
    
@purchase_bp.route('/get_purchase_to_return/<int:order_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager', 'cashier')
def get_purchase_to_return(order_id):
    """
    Get a single purchase order (header + items) for the purchase-return flow.

    For every order line the response reports:
      - purchase_order_quantity:   order_items.quantity
      - actual_received_quantity:  SUM(grn_items.received_quantity) over this
                                   order's GRNs with status 'completed'
      - already_returned_quantity: SUM(purchase_return_items.quantity) over this
                                   order's returns with status 'Pending' or 'Received'
      - current_warehouse_stock:   SUM(warehouse_stock.quantity) in the order's
                                   store/warehouse (informational only)
      - max_returnable_quantity:   GREATEST(0, received - already_returned)

    Example: ordered 10, GRN received 9, already returned 2 -> max returnable 7.

    Args:
        order_id: purchase_orders.order_id from the URL.

    Returns:
        (json, 200) with order, supplier, location and items;
        (json, 404) if the order does not exist;
        (json, 500) on connection failure or any other error.
    """
    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        # ===========================
        # Step 1: Get Purchase Order Header
        # ===========================
        cursor.execute("""
            SELECT 
                po.order_id, po.order_tax, po.discount, po.status, 
                po.grand_total, po.payment_status, po.payment_type,
                po.paid_amount, po.due_amount, po.note,
                po.created_on as order_date,
                po.warehouse_id, po.store_id,
                s.id as supplier_id, s.supplier_name, s.supplier_code,
                s.supplier_contact, s.supplier_email, 
                s.supplier_optional_contact, s.supplier_address,
                w.warehouse_name,
                st.store_name
            FROM purchase_orders po
            LEFT JOIN suppliers s ON po.supplier_id = s.id
            LEFT JOIN warehouses w ON po.warehouse_id = w.id
            LEFT JOIN stores st ON po.store_id = st.id
            WHERE po.order_id = %s
        """, (order_id,))

        order_data = cursor.fetchone()

        if not order_data:
            return jsonify({'error': 'Purchase order not found'}), 404

        # ===========================
        # Step 2: Get Order Items with per-line quantities
        # The received / returned sums are computed ONCE in the derived table
        # `t`; the outer SELECT reuses them for max_returnable_quantity instead
        # of re-running both correlated subqueries a second time per row.
        # (In the variation predicates, AND binds tighter than OR: both NULL,
        # or equal ids.)
        # ===========================
        cursor.execute("""
            SELECT
                t.*,
                GREATEST(0, t.actual_received_quantity - t.already_returned_quantity)
                    AS max_returnable_quantity
            FROM (
                SELECT 
                    oi.item_id,
                    oi.product_id,
                    oi.variation_id,
                    oi.quantity as purchase_order_quantity,
                    oi.unit_price,
                    oi.net_unit_cost,
                    oi.purchase_unit,
                    oi.discount_type,
                    oi.product_discount,
                    oi.tax_type,
                    oi.product_tax,
                    oi.discount,
                    oi.tax,
                    oi.subtotal,
                    oi.expiration_date,
                    p.product_name,
                    p.sku,
                    pv.variation_name,
                    pv.variation_sku,

                    -- Received quantity from THIS purchase order's completed GRNs
                    COALESCE(
                        (SELECT SUM(gi.received_quantity)
                         FROM grn_items gi
                         INNER JOIN grn g ON gi.grn_id = g.grn_id
                         WHERE g.purchase_order_id = %s
                         AND g.status = 'completed'
                         AND gi.product_id = oi.product_id
                         AND (gi.variation_id IS NULL AND oi.variation_id IS NULL 
                              OR gi.variation_id = oi.variation_id)
                        ), 0
                    ) as actual_received_quantity,

                    -- Quantity already returned (Pending or Received) for THIS order
                    COALESCE(
                        (SELECT SUM(pri.quantity)
                         FROM purchase_return_items pri
                         INNER JOIN purchase_returns pr ON pri.return_id = pr.return_id
                         WHERE pr.order_id = %s
                         AND pr.return_status IN ('Pending', 'Received')
                         AND pri.product_id = oi.product_id
                         AND (pri.variation_id IS NULL AND oi.variation_id IS NULL 
                              OR pri.variation_id = oi.variation_id)
                        ), 0
                    ) as already_returned_quantity,

                    -- Current warehouse stock (informational)
                    COALESCE(
                        (SELECT SUM(ws.quantity)
                         FROM warehouse_stock ws
                         WHERE ws.store_id = %s
                         AND ws.warehouse_id = %s
                         AND ws.product_id = oi.product_id
                         AND (ws.variation_id IS NULL AND oi.variation_id IS NULL 
                              OR ws.variation_id = oi.variation_id)
                        ), 0
                    ) as current_warehouse_stock

                FROM order_items oi
                LEFT JOIN products p ON oi.product_id = p.id
                LEFT JOIN product_variations pv ON oi.variation_id = pv.id
                WHERE oi.order_id = %s
            ) AS t
        """, (
            order_id,                    # GRN received quantity
            order_id,                    # already returned quantity
            order_data['store_id'],      # warehouse_stock store filter
            order_data['warehouse_id'],  # warehouse_stock warehouse filter
            order_id                     # order_items WHERE clause
        ))

        items = cursor.fetchall()

        # ===========================
        # Step 3: Build per-item response
        # ===========================
        formatted_items = []
        for item in items:
            formatted_items.append({
                'item_id': item['item_id'],
                'product_id': item['product_id'],
                'variation_id': item['variation_id'],
                'product_name': item['product_name'],
                'sku': item['sku'],
                'variation_name': item['variation_name'],
                'variation_sku': item['variation_sku'],

                # Quantity information
                'purchase_order_quantity': float(item['purchase_order_quantity']),
                'actual_received_quantity': float(item['actual_received_quantity']),
                'already_returned_quantity': float(item['already_returned_quantity']),
                'current_warehouse_stock': float(item['current_warehouse_stock']),
                'max_returnable_quantity': float(item['max_returnable_quantity']),

                # Pricing information
                'unit_price': float(item['unit_price']),
                'net_unit_cost': float(item['net_unit_cost']),
                'purchase_unit': item['purchase_unit'],
                'discount_type': item['discount_type'],
                'product_discount': float(item['product_discount']),
                'tax_type': item['tax_type'],
                'product_tax': float(item['product_tax']),
                'discount': float(item['discount']),
                'tax': float(item['tax']),
                'subtotal': float(item['subtotal']),
                'expiration_date': item['expiration_date'].strftime('%Y-%m-%d') if item['expiration_date'] else None
            })

        # ===========================
        # Step 4: Build Final Response
        # ===========================
        response = {
            'order_id': order_data['order_id'],
            'order_date': order_data['order_date'].strftime('%Y-%m-%d %H:%M:%S') if order_data['order_date'] else None,
            'order_tax': float(order_data['order_tax']),
            'discount': float(order_data['discount']),
            'status': order_data['status'],
            'grand_total': float(order_data['grand_total']),
            'payment_status': order_data['payment_status'],
            'payment_type': order_data['payment_type'],
            'paid_amount': float(order_data['paid_amount']) if order_data['paid_amount'] else 0.00,
            'due_amount': float(order_data['due_amount']) if order_data['due_amount'] else 0.00,
            'note': order_data['note'],
            'supplier': {
                'id': order_data['supplier_id'],
                'supplier_name': order_data['supplier_name'],
                'supplier_code': order_data['supplier_code'],
                'supplier_contact': order_data['supplier_contact'],
                'supplier_email': order_data['supplier_email'],
                'supplier_optional_contact': order_data['supplier_optional_contact'],
                'supplier_address': order_data['supplier_address']
            },
            'location': {
                'warehouse_id': order_data['warehouse_id'],
                'warehouse_name': order_data['warehouse_name'],
                'store_id': order_data['store_id'],
                'store_name': order_data['store_name']
            },
            'items': formatted_items
        }

        return jsonify(response), 200

    except Exception as e:
        print(f"❌ Error fetching purchase order: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()