import json
import math
import traceback
from datetime import datetime

import mysql.connector
from flask import Blueprint, jsonify, request
from flask_jwt_extended import get_jwt_identity, jwt_required

from config.auth import role_required
from db.db import get_db_connection

# Blueprint on which every stock-adjustment route in this module is
# registered (via the @adjustment_bp.route decorators below).
adjustment_bp = Blueprint('adjustment', __name__)


def safe_float(value, default=0.0):
    """Convert *value* to a finite float, falling back to *default*.

    ``None``, ``''`` and the string ``'null'`` map to *default*. So does
    anything ``float()`` rejects (``ValueError``/``TypeError``), or an int
    too large for a double (``OverflowError``).

    Non-finite results (``nan``, ``inf``, ``-inf``) also map to *default*.
    Strings such as ``"nan"`` or ``"1e400"`` parse successfully but are
    never meaningful quantities or prices. NaN in particular compares
    False against everything, so it would slip past ``<= 0`` guards.

    Args:
        value: Raw value, e.g. from request JSON or a database row.
        default: Value returned when conversion is impossible or unsafe.

    Returns:
        The converted finite float, or *default*.
    """
    try:
        if value in (None, '', 'null'):
            return default
        result = float(value)
    except (ValueError, TypeError, OverflowError):
        return default
    return result if math.isfinite(result) else default


def safe_int(value, default=0):
    """Convert *value* to an int, falling back to *default*.

    ``None``, ``''`` and the string ``'null'`` map to *default*. So does
    anything ``int()`` rejects. That includes non-integer strings such as
    ``"4.2"`` (``ValueError``) and ``float('nan')`` (``ValueError``). It
    also includes ``float('inf')`` (``OverflowError``), which a JSON number
    like ``1e400`` decodes to. Float inputs are truncated toward zero, as
    ``int()`` does.

    Args:
        value: Raw value, e.g. from request JSON or query parameters.
        default: Value returned when conversion is impossible.

    Returns:
        The converted int, or *default*.
    """
    try:
        return int(value) if value not in (None, '', 'null') else default
    except (ValueError, TypeError, OverflowError):
        return default


# ============================================
# ✅ SIMPLIFIED PRODUCT SEARCH FOR ADJUSTMENTS
# ============================================
@adjustment_bp.route('/adjustment_product_search', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def adjustment_product_search():
    """
    Search in-stock product/variation/batch rows for a stock adjustment.

    Query parameters:
        warehouse_id (required): warehouse to search in.
        store_id (required): store to search in.
        query (optional): free text matched as a substring against the
            product name, product SKU, variation name, variation SKU and
            batch number. ``%`` and ``_`` in the text are matched
            literally, not as SQL wildcards.

    One result is returned per ``warehouse_stock`` row with
    ``quantity > 0`` (no grouping). Rows are ordered by product name,
    variation name and expiration date, and capped at 50.

    Returns:
        200 with a JSON list of formatted stock rows;
        400 when warehouse_id or store_id is missing;
        500 when the database is unavailable or a MySQL error occurs.
    """
    warehouse_id = request.args.get('warehouse_id', '').strip()
    store_id = request.args.get('store_id', '').strip()
    query = request.args.get('query', '').strip()

    if not warehouse_id or not store_id:
        return jsonify({
            "status": "error",
            "message": "warehouse_id and store_id are required"
        }), 400

    conn = get_db_connection()
    if conn is None:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Created inside the try so `finally` still closes the connection
        # if opening the cursor fails.
        cursor = conn.cursor(dictionary=True)

        search_condition = ""
        search_params = [warehouse_id, store_id]

        if query:
            # '!' is declared as the LIKE escape character so user-typed
            # '%' and '_' match literally. A non-backslash escape keeps this
            # working even if the server runs with NO_BACKSLASH_ESCAPES.
            search_condition = """
                AND (
                    p.product_name LIKE %s ESCAPE '!'
                    OR p.sku LIKE %s ESCAPE '!'
                    OR pv.variation_name LIKE %s ESCAPE '!'
                    OR pv.variation_sku LIKE %s ESCAPE '!'
                    OR pb.batch_number LIKE %s ESCAPE '!'
                )
            """
            escaped_query = (
                query.replace('!', '!!').replace('%', '!%').replace('_', '!_')
            )
            search_params.extend([f"%{escaped_query}%"] * 5)

        query_sql = f"""
            SELECT
                ws.id as stock_id,
                ws.product_id,
                ws.variation_id,
                ws.batch_id,
                ws.quantity as stock_quantity,

                -- Product Details
                p.product_name,
                p.sku,
                p.product_type,

                -- Variation Details
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,

                -- Batch Details
                pb.batch_number,
                pb.cost,
                pb.price,
                pb.expiration_date,

                -- GRN & Supplier
                g.grn_code,
                s.supplier_name,
                s.supplier_code

            FROM warehouse_stock ws
            INNER JOIN products p ON ws.product_id = p.id
            LEFT JOIN product_variations pv ON ws.variation_id = pv.id
            LEFT JOIN product_batches pb ON ws.batch_id = pb.batch_id
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id

            WHERE ws.warehouse_id = %s
            AND ws.store_id = %s
            AND ws.quantity > 0
            {search_condition}

            ORDER BY p.product_name, pv.variation_name, pb.expiration_date
            LIMIT 50
        """

        cursor.execute(query_sql, search_params)
        results = cursor.fetchall()

        formatted_results = []
        for item in results:
            # Variation rows show "<product> - <variation> (<type>)" and use
            # the variation SKU. Plain products use the product name/SKU.
            if item['variation_id']:
                variation_type = f" ({item['variation_type']})" if item['variation_type'] else ''
                product_name = f"{item['product_name']} - {item['variation_name']}{variation_type}"
                sku = item['variation_sku']
            else:
                product_name = item['product_name']
                sku = item['sku']

            cost = float(item['cost'] or 0)
            price = float(item['price'] or 0)
            stock = float(item['stock_quantity'] or 0)
            exp_date = item['expiration_date'].strftime('%Y-%m-%d') if item['expiration_date'] else '-'
            batch_num = item['batch_number'] or 'No Batch'
            supplier = item['supplier_name'] or 'Unknown'
            grn_code = item['grn_code'] or '-'

            display_name = (
                f"{product_name} | "
                f"Batch: {batch_num} | "
                f"GRN: {grn_code} | "
                f"Supplier: {supplier} | "
                f"Cost {cost:.2f} - Price {price:.2f} | "
                f"Stock {stock:.2f} | "
                f"Exp: {exp_date}"
            )

            formatted_results.append({
                "stock_id": item['stock_id'],
                "product_id": item['product_id'],
                "variation_id": item['variation_id'],
                "batch_id": item['batch_id'],
                "product_name": product_name,
                "sku": sku,
                "display_name": display_name,
                "product_cost": cost,
                "product_price": price,
                "product_quantity": stock,
                "batch_number": batch_num,
                "grn_code": grn_code,
                "supplier_name": supplier,
                "expiration_date": exp_date
            })

        return jsonify(formatted_results), 200

    except mysql.connector.Error as err:
        print(f"❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": f"Database error: {str(err)}"
        }), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# ✅ CREATE STOCK ADJUSTMENT
# ============================================
@adjustment_bp.route('/create_adjustment', methods=['POST'])
@jwt_required()
@role_required('admin', 'cashier')
def create_adjustment():
    """
    Create Stock Adjustment with Batch Tracking

    Expected JSON:
    {
        "warehouse_id": 1,
        "store_id": 2,
        "adjustment_date": "2026-01-30",
        "reason": "Stock count correction",
        "note": "Monthly inventory check",
        "items": [
            {
                "product_id": 5,
                "variation_id": null,
                "batch_id": 10,
                "quantity": 10,
                "adjustment_type": "addition",
                "note": "Found extra stock"
            }
        ]
    }

    Item rules:
      * ``adjustment_type`` is ``"addition"`` or ``"subtraction"``.
        It is case-insensitive, and missing/null/empty means
        ``"addition"``. Any other value is rejected with 400 before
        anything is written.
      * Items whose ``quantity`` is not a positive number are skipped.
        At least one item must have a positive quantity.

    For every applied item, inside one transaction:
      * the product, variation (if given) and batch (if given) are checked
        to exist and belong together;
      * a ``stock_adjustment_items`` row is inserted;
      * ``product_batches.remaining_quantity`` is increased/decreased when
        a batch is given;
      * the matching ``warehouse_stock`` row is increased/decreased. For an
        addition with no matching row, a new row is inserted.

    The batch and warehouse-stock rows are read with ``FOR UPDATE``. This
    stops a concurrent transaction from changing them between the
    availability check and the write. Any error rolls back the whole
    adjustment.

    The header's ``total_items`` is the number of applied items. Its
    ``adjustment_type`` is ``"subtraction"`` only when every applied item
    is a subtraction; otherwise it is ``"addition"``.

    NOTE: ``adjustment_date`` is accepted but not persisted. The header is
    stamped with ``NOW()``.

    Returns:
        201 with ``adjustment_id``, ``adjustment_code``, ``items_processed``
        and ``adjustment_type``;
        400 for malformed input, unknown references or insufficient stock;
        500 for database or unexpected server errors.
    """
    try:
        data = request.get_json(force=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400

    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    print("=" * 80)
    print("📝 STOCK ADJUSTMENT CREATION")
    print("=" * 80)
    print(json.dumps(data, indent=2, default=str))
    print("=" * 80)

    # Extract data
    warehouse_id = safe_int(data.get('warehouse_id'))
    store_id = safe_int(data.get('store_id'))
    reason = data.get('reason', '')
    note = data.get('note', '')
    items = data.get('items', [])

    # Validation
    if not warehouse_id:
        return jsonify({'error': 'Warehouse is required'}), 400

    if not store_id:
        return jsonify({'error': 'Store is required'}), 400

    if not items:
        return jsonify({'error': 'At least one item is required'}), 400

    if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
        return jsonify({'error': 'Items must be a list of objects'}), 400

    # Normalise and validate every item's type up front. Previously an
    # unknown type was recorded as an item that changed no stock at all.
    item_types = []
    for idx, item in enumerate(items, 1):
        item_type = str(item.get('adjustment_type') or 'addition').strip().lower()
        if item_type not in ('addition', 'subtraction'):
            return jsonify({
                'error': f"Item {idx}: adjustment_type must be 'addition' or 'subtraction'"
            }), 400
        item_types.append(item_type)

    quantities = [safe_float(item.get('quantity', 0)) for item in items]
    # `q > 0` is False for NaN as well as for zero/negative values.
    applied_types = [t for t, q in zip(item_types, quantities) if q > 0]
    if not applied_types:
        return jsonify({'error': 'At least one item with a positive quantity is required'}), 400

    # Overall type: 'subtraction' only if every applied item subtracts.
    adjustment_type = 'subtraction' if set(applied_types) == {'subtraction'} else 'addition'

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = None
    try:
        # Created inside the try so `finally` still closes the connection
        # if opening the cursor fails.
        cursor = conn.cursor(dictionary=True, buffered=True)

        conn.start_transaction()
        print("\n🔄 Transaction started")

        adjustment_code = f"ADJ-{datetime.now().strftime('%Y%m%d%H%M%S')}"

        current_user = get_jwt_identity()
        print(f"  User: {current_user}")

        print(f"\n📝 Creating adjustment: {adjustment_code}")
        cursor.execute("""
            INSERT INTO stock_adjustments (
                adjustment_code, warehouse_id, store_id,
                adjustment_type, reason, note, total_items,
                created_by, created_at
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
        """, (adjustment_code, warehouse_id, store_id, adjustment_type,
              reason, note, len(applied_types), current_user))

        adjustment_id = cursor.lastrowid
        print(f"✅ Adjustment created with ID: {adjustment_id}")

        print(f"\n📝 Processing {len(items)} items...")

        items_processed = 0

        for idx, (item, item_adjustment_type, quantity) in enumerate(
                zip(items, item_types, quantities), 1):
            if not quantity > 0:
                print(f"  ⚠️  Skipping item {idx}: quantity is 0 or negative")
                continue

            try:
                product_id = safe_int(item.get('product_id'))
                variation_id = safe_int(item.get('variation_id')) if item.get('variation_id') else None
                batch_id = safe_int(item.get('batch_id')) if item.get('batch_id') else None
                item_note = item.get('note', '')

                print(f"\n  📦 Processing item {idx}:")
                print(f"    Product ID: {product_id}")
                print(f"    Variation ID: {variation_id}")
                print(f"    Batch ID: {batch_id}")
                print(f"    Quantity: {quantity}")
                print(f"    Type: {item_adjustment_type}")

                # Validate product exists
                cursor.execute("SELECT id FROM products WHERE id = %s", (product_id,))
                if not cursor.fetchone():
                    raise ValueError(f"Product ID {product_id} does not exist")

                # Validate variation belongs to the product
                if variation_id:
                    cursor.execute("""
                        SELECT id FROM product_variations
                        WHERE id = %s AND product_id = %s
                    """, (variation_id, product_id))
                    if not cursor.fetchone():
                        raise ValueError(f"Variation ID {variation_id} does not exist for product {product_id}")

                # Validate batch (locked until commit/rollback)
                if batch_id:
                    cursor.execute("""
                        SELECT batch_id, remaining_quantity
                        FROM product_batches
                        WHERE batch_id = %s AND product_id = %s
                        AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
                        FOR UPDATE
                    """, (batch_id, product_id, variation_id, variation_id))

                    batch = cursor.fetchone()
                    if not batch:
                        raise ValueError(f"Batch ID {batch_id} does not exist for this product")

                    if item_adjustment_type == 'subtraction':
                        batch_qty = safe_float(batch['remaining_quantity'])
                        if batch_qty < quantity:
                            raise ValueError(
                                f"Batch {batch_id} has insufficient quantity. "
                                f"Available: {batch_qty}, Requested: {quantity}"
                            )

                # Insert adjustment item
                cursor.execute("""
                    INSERT INTO stock_adjustment_items (
                        adjustment_id, product_id, variation_id, batch_id,
                        quantity, adjustment_type, note, created_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
                """, (adjustment_id, product_id, variation_id, batch_id,
                      quantity, item_adjustment_type, item_note))

                item_id = cursor.lastrowid
                print(f"    ✅ Adjustment item created with ID: {item_id}")

                # Update product_batches.remaining_quantity
                if batch_id:
                    if item_adjustment_type == 'addition':
                        cursor.execute("""
                            UPDATE product_batches
                            SET remaining_quantity = remaining_quantity + %s
                            WHERE batch_id = %s
                        """, (quantity, batch_id))
                        print(f"    ✅ Updated batch {batch_id}: +{quantity}")
                    else:
                        cursor.execute("""
                            UPDATE product_batches
                            SET remaining_quantity = remaining_quantity - %s
                            WHERE batch_id = %s
                        """, (quantity, batch_id))
                        print(f"    ✅ Updated batch {batch_id}: -{quantity}")

                # Lock the matching warehouse_stock row. The new quantity is
                # computed from this read, so it must not change underneath us.
                cursor.execute("""
                    SELECT id, quantity FROM warehouse_stock
                    WHERE store_id = %s
                    AND warehouse_id = %s
                    AND product_id = %s
                    AND (variation_id IS NULL AND %s IS NULL OR variation_id = %s)
                    AND (batch_id IS NULL AND %s IS NULL OR batch_id = %s)
                    FOR UPDATE
                """, (store_id, warehouse_id, product_id,
                      variation_id, variation_id, batch_id, batch_id))

                existing_stock = cursor.fetchone()

                stock_label = (
                    f"product ID {product_id}"
                    + (f" variation {variation_id}" if variation_id else "")
                    + (f" batch {batch_id}" if batch_id else "")
                )

                if item_adjustment_type == 'addition':
                    if existing_stock:
                        current_qty = safe_float(existing_stock['quantity'])
                        new_qty = current_qty + quantity

                        cursor.execute("""
                            UPDATE warehouse_stock
                            SET quantity = %s, updated_at = NOW()
                            WHERE id = %s
                        """, (new_qty, existing_stock['id']))

                        print(f"    ✅ Updated warehouse stock: {current_qty} + {quantity} = {new_qty}")
                    else:
                        cursor.execute("""
                            INSERT INTO warehouse_stock (
                                store_id, warehouse_id, product_id, variation_id,
                                batch_id, quantity, updated_at
                            ) VALUES (%s, %s, %s, %s, %s, %s, NOW())
                        """, (store_id, warehouse_id, product_id, variation_id,
                              batch_id, quantity))

                        print(f"    ✅ Created new warehouse stock entry: {quantity}")
                else:
                    if not existing_stock:
                        raise ValueError(
                            f"Cannot subtract from non-existent warehouse stock for {stock_label}"
                        )

                    current_qty = safe_float(existing_stock['quantity'])
                    if current_qty < quantity:
                        raise ValueError(
                            f"Insufficient warehouse stock for {stock_label}. "
                            f"Available: {current_qty}, Requested: {quantity}"
                        )

                    new_qty = current_qty - quantity

                    cursor.execute("""
                        UPDATE warehouse_stock
                        SET quantity = %s, updated_at = NOW()
                        WHERE id = %s
                    """, (new_qty, existing_stock['id']))

                    print(f"    ✅ Updated warehouse stock: {current_qty} - {quantity} = {new_qty}")

                items_processed += 1

            except Exception as item_error:
                print(f"  ❌ Error processing item {idx}: {item_error}")
                raise

        conn.commit()
        print("\n✅ Transaction committed successfully!")
        print(f"✅ Items processed: {items_processed}/{len(items)}")

        print("\n" + "=" * 80)
        print(f"✅ SUCCESS: Adjustment {adjustment_code} created")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': 'Adjustment created successfully',
            'adjustment_id': adjustment_id,
            'adjustment_code': adjustment_code,
            'items_processed': items_processed,
            'adjustment_type': adjustment_type
        }), 201

    except ValueError as err:
        conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.IntegrityError as err:
        conn.rollback()
        print(f"\n❌ Database Integrity Error: {err}")

        error_msg = str(err)
        if 'foreign key constraint' in error_msg.lower():
            return jsonify({'error': 'Invalid reference - check warehouse, product, or batch IDs'}), 400

        return jsonify({'error': f'Database integrity error: {error_msg}'}), 500

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")


# ============================================
# ✅ GET ALL ADJUSTMENTS WITH FILTERING
# ============================================
@adjustment_bp.route('/get_adjustments', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_adjustments():
    """
    List stock adjustments, newest first, with pagination and filtering.

    Query parameters (all optional):
        page: 1-based page number. Defaults to 1; non-numeric values fall
            back to 1 and values below 1 are raised to 1.
        per_page: page size. Defaults to 20; non-numeric values fall back
            to 20 and values below 1 are raised to 1. This avoids a
            division by zero in the page count.
        warehouse_id, store_id, adjustment_type: exact-match filters.
        search: substring match on ``adjustment_code``. ``%`` and ``_``
            are matched literally.

    Returns:
        200 ``{'success': True, 'adjustments': [...], 'pagination': {...}}``;
        500 with ``{'error': ...}`` if anything fails.
    """
    conn = None
    cursor = None

    try:
        # Clamp rather than trust raw input. A negative page produced a
        # negative OFFSET (SQL error), and per_page=0 crashed the page count.
        page = max(1, safe_int(request.args.get('page'), 1))
        per_page = max(1, safe_int(request.args.get('per_page'), 20))
        warehouse_id = request.args.get('warehouse_id')
        store_id = request.args.get('store_id')
        adjustment_type = request.args.get('adjustment_type')
        search = request.args.get('search', '').strip()

        offset = (page - 1) * per_page

        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        # Build WHERE clause. Only fixed SQL fragments are concatenated;
        # all user values go through parameters.
        where_clauses = []
        params = []

        if warehouse_id:
            where_clauses.append("sa.warehouse_id = %s")
            params.append(warehouse_id)

        if store_id:
            where_clauses.append("sa.store_id = %s")
            params.append(store_id)

        if adjustment_type:
            where_clauses.append("sa.adjustment_type = %s")
            params.append(adjustment_type)

        if search:
            # '!' is the explicit LIKE escape so '%' and '_' match literally,
            # independent of the server's NO_BACKSLASH_ESCAPES setting.
            escaped_search = (
                search.replace('!', '!!').replace('%', '!%').replace('_', '!_')
            )
            where_clauses.append("sa.adjustment_code LIKE %s ESCAPE '!'")
            params.append(f"%{escaped_search}%")

        where_sql = "WHERE " + " AND ".join(where_clauses) if where_clauses else ""

        cursor.execute(f"""
            SELECT COUNT(*) as total
            FROM stock_adjustments sa
            {where_sql}
        """, params)

        total = cursor.fetchone()['total']

        cursor.execute(f"""
            SELECT
                sa.adjustment_id,
                sa.adjustment_code,
                sa.warehouse_id,
                w.warehouse_name,
                sa.store_id,
                st.store_name,
                sa.adjustment_type,
                sa.reason,
                sa.note,
                sa.total_items,
                sa.created_by,
                u.name as created_by_name,
                sa.created_at,
                sa.updated_at
            FROM stock_adjustments sa
            LEFT JOIN warehouses w ON sa.warehouse_id = w.id
            LEFT JOIN stores st ON sa.store_id = st.id
            LEFT JOIN users u ON sa.created_by = u.id
            {where_sql}
            ORDER BY sa.created_at DESC
            LIMIT %s OFFSET %s
        """, params + [per_page, offset])

        adjustments = cursor.fetchall()

        for adjustment in adjustments:
            adjustment['created_at'] = adjustment['created_at'].strftime('%Y-%m-%d %H:%M:%S') if adjustment['created_at'] else None
            adjustment['updated_at'] = adjustment['updated_at'].strftime('%Y-%m-%d %H:%M:%S') if adjustment['updated_at'] else None

        return jsonify({
            'success': True,
            'adjustments': adjustments,
            'pagination': {
                'page': page,
                'per_page': per_page,
                'total': total,
                'pages': (total + per_page - 1) // per_page
            }
        }), 200

    except Exception as e:
        print(f"❌ Error fetching adjustments: {e}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# ✅ GET SINGLE ADJUSTMENT WITH ITEMS
# ============================================
@adjustment_bp.route('/get_adjustment/<int:adjustment_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_adjustment(adjustment_id):
    """
    Fetch one stock adjustment header together with its line items.

    Each line item is LEFT JOINed with its product, variation, batch, GRN
    and supplier rows. Timestamps are rendered as ``YYYY-MM-DD HH:MM:SS``,
    expiration dates as ``YYYY-MM-DD``, and quantity/cost/price as floats.

    Returns:
        200 ``{'success': True, 'adjustment': {..., 'items': [...]}}``;
        404 when no adjustment has this id;
        500 when the database is unavailable or any other error occurs.
    """
    stamp_fmt = '%Y-%m-%d %H:%M:%S'

    def render(value, fmt):
        # Falsy values (None) pass through as None; dates become strings.
        return value.strftime(fmt) if value else None

    conn = None
    cursor = None

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        cursor.execute("""
            SELECT 
                sa.adjustment_id,
                sa.adjustment_code,
                sa.warehouse_id,
                w.warehouse_name,
                sa.store_id,
                st.store_name,
                sa.adjustment_type,
                sa.reason,
                sa.note,
                sa.total_items,
                sa.created_by,
                u.name as created_by_name,
                sa.created_at,
                sa.updated_at
            FROM stock_adjustments sa
            LEFT JOIN warehouses w ON sa.warehouse_id = w.id
            LEFT JOIN stores st ON sa.store_id = st.id
            LEFT JOIN users u ON sa.created_by = u.id
            WHERE sa.adjustment_id = %s
        """, (adjustment_id,))
        header = cursor.fetchone()

        if not header:
            return jsonify({'error': 'Adjustment not found'}), 404

        cursor.execute("""
            SELECT 
                sai.item_id,
                sai.product_id,
                p.product_name,
                p.sku,
                sai.variation_id,
                pv.variation_name,
                pv.variation_type,
                pv.variation_sku,
                sai.batch_id,
                pb.batch_number,
                pb.cost as batch_cost,
                pb.price as batch_price,
                pb.expiration_date,
                g.grn_code,
                s.supplier_name,
                sai.quantity,
                sai.adjustment_type,
                sai.note,
                sai.created_at
            FROM stock_adjustment_items sai
            LEFT JOIN products p ON sai.product_id = p.id
            LEFT JOIN product_variations pv ON sai.variation_id = pv.id
            LEFT JOIN product_batches pb ON sai.batch_id = pb.batch_id
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            WHERE sai.adjustment_id = %s
            ORDER BY sai.item_id
        """, (adjustment_id,))
        rows = cursor.fetchall()

        for stamp_key in ('created_at', 'updated_at'):
            header[stamp_key] = render(header[stamp_key], stamp_fmt)

        for row in rows:
            row['quantity'] = float(row['quantity'])
            row['batch_cost'] = float(row['batch_cost'] or 0)
            row['batch_price'] = float(row['batch_price'] or 0)
            row['expiration_date'] = render(row['expiration_date'], '%Y-%m-%d')
            row['created_at'] = render(row['created_at'], stamp_fmt)

        header['items'] = rows

        return jsonify({
            'success': True,
            'adjustment': header
        }), 200

    except Exception as exc:
        print(f"❌ Error fetching adjustment: {exc}")
        traceback.print_exc()
        return jsonify({'error': str(exc)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# ============================================
# ✅ UPDATE STOCK ADJUSTMENT
# ============================================
_VALID_ADJUSTMENT_TYPES = ('addition', 'subtraction')


def _rollback_quietly(conn):
    """Roll back ``conn`` if it was opened.

    ``conn`` is still ``None`` when ``get_db_connection()`` itself raised, and a
    failing rollback must not mask the error that is being reported.
    """
    if conn is None:
        return
    try:
        conn.rollback()
    except Exception as rollback_err:  # best effort: the original error matters more
        print(f"⚠️  Rollback failed: {rollback_err}")


def _normalize_update_items(raw_items):
    """Validate the request's ``items`` list and return the items to persist.

    Items whose quantity is zero or negative are skipped (and logged). Each
    returned item is a dict with ``index`` (1-based position in the request),
    ``product_id``, ``variation_id`` / ``batch_id`` (``None`` when missing,
    zero or unparsable), ``quantity``, ``adjustment_type`` (lower-cased) and
    ``note``.

    Raises:
        ValueError: with a client-facing message when an item is not an
            object, has a non-finite quantity, an unknown adjustment type,
            or no product id.
    """
    import math

    normalized = []
    for idx, raw in enumerate(raw_items, 1):
        if not isinstance(raw, dict):
            raise ValueError(f"Item {idx} must be a JSON object")

        quantity = safe_float(raw.get('quantity', 0))
        if not math.isfinite(quantity):
            raise ValueError(f"Item {idx}: quantity must be a finite number")
        if quantity <= 0:
            print(f"  ⚠️  Skipping item {idx}: quantity is 0 or negative")
            continue

        # A JSON null falls back to the default instead of crashing on .lower().
        item_type = str(raw.get('adjustment_type') or 'addition').strip().lower()
        if item_type not in _VALID_ADJUSTMENT_TYPES:
            raise ValueError(
                f"Item {idx}: adjustment_type must be 'addition' or 'subtraction'"
            )

        product_id = safe_int(raw.get('product_id'))
        if product_id <= 0:
            raise ValueError(f"Item {idx}: product_id is required")

        normalized.append({
            'index': idx,
            'product_id': product_id,
            'variation_id': safe_int(raw.get('variation_id')) or None,
            'batch_id': safe_int(raw.get('batch_id')) or None,
            'quantity': quantity,
            'adjustment_type': item_type,
            'note': raw.get('note', ''),
        })
    return normalized


def _lock_warehouse_stock_row(cursor, store_id, warehouse_id, product_id,
                              variation_id, batch_id):
    """Return the matching ``warehouse_stock`` row (``id``, ``quantity``) or ``None``.

    ``FOR UPDATE`` keeps concurrent requests from reading the same quantity
    and overwriting each other's result. ``<=>`` is MySQL's NULL-safe
    equality, so a ``None`` variation/batch id matches NULL columns.
    """
    cursor.execute("""
        SELECT id, quantity FROM warehouse_stock
        WHERE store_id = %s AND warehouse_id = %s
          AND product_id = %s
          AND variation_id <=> %s
          AND batch_id <=> %s
        FOR UPDATE
    """, (store_id, warehouse_id, product_id, variation_id, batch_id))
    return cursor.fetchone()


def _set_warehouse_quantity(cursor, stock_id, quantity):
    """Overwrite the quantity of one ``warehouse_stock`` row."""
    cursor.execute("""
        UPDATE warehouse_stock
        SET quantity = %s, updated_at = NOW()
        WHERE id = %s
    """, (quantity, stock_id))


def _insert_warehouse_stock(cursor, store_id, warehouse_id, product_id,
                            variation_id, batch_id, quantity):
    """Create a ``warehouse_stock`` row holding ``quantity``."""
    cursor.execute("""
        INSERT INTO warehouse_stock (
            store_id, warehouse_id, product_id, variation_id,
            batch_id, quantity, updated_at
        ) VALUES (%s, %s, %s, %s, %s, %s, NOW())
    """, (store_id, warehouse_id, product_id, variation_id, batch_id, quantity))


def _shift_batch_remaining(cursor, batch_id, delta):
    """Add ``delta`` (may be negative) to the batch's remaining_quantity.

    This is a no-op for items without a batch.
    """
    if not batch_id:
        return
    cursor.execute("""
        UPDATE product_batches
        SET remaining_quantity = remaining_quantity + %s
        WHERE batch_id = %s
    """, (delta, batch_id))


def _reverse_adjustment_item(cursor, store_id, warehouse_id, item):
    """Undo one stored ``stock_adjustment_items`` row.

    ``store_id`` / ``warehouse_id`` are the location the item was originally
    applied to. Reversing an addition removes stock, and fails if the
    warehouse row is gone or would go negative. Reversing a subtraction
    restores stock, recreating the warehouse row if it was removed. Items
    with any other type never changed stock and are skipped.

    Raises:
        ValueError: when the reversal would leave negative or missing stock.
    """
    product_id = item['product_id']
    variation_id = item['variation_id']
    batch_id = item['batch_id']
    quantity = safe_float(item['quantity'])
    item_type = item['adjustment_type']

    if item_type not in _VALID_ADJUSTMENT_TYPES:
        print(f"    ⚠️  Unknown adjustment type {item_type!r}; nothing to reverse")
        return

    removing = item_type == 'addition'
    _shift_batch_remaining(cursor, batch_id, -quantity if removing else quantity)
    if batch_id:
        print(f"    ✅ Reversed batch {batch_id}: {'-' if removing else '+'}{quantity}")

    stock = _lock_warehouse_stock_row(cursor, store_id, warehouse_id, product_id,
                                      variation_id, batch_id)
    if stock is None:
        # Previously skipped silently, which left the batch already adjusted.
        if removing:
            raise ValueError(
                "Cannot reverse adjustment: warehouse stock record no longer exists"
            )
        _insert_warehouse_stock(cursor, store_id, warehouse_id, product_id,
                                variation_id, batch_id, quantity)
        print(f"    ✅ Recreated warehouse stock: {quantity}")
        return

    current_qty = safe_float(stock['quantity'])
    if removing:
        new_qty = current_qty - quantity
        if new_qty < 0:
            raise ValueError(
                f"Cannot reverse adjustment: would result in negative stock. "
                f"Current: {current_qty}, To subtract: {quantity}"
            )
        _set_warehouse_quantity(cursor, stock['id'], new_qty)
        print(f"    ✅ Reversed warehouse stock: {current_qty} - {quantity} = {new_qty}")
    else:
        new_qty = current_qty + quantity
        _set_warehouse_quantity(cursor, stock['id'], new_qty)
        print(f"    ✅ Reversed warehouse stock: {current_qty} + {quantity} = {new_qty}")


def _apply_adjustment_item(cursor, store_id, warehouse_id, item):
    """Apply one normalized request item to product_batches and warehouse_stock.

    Additions create the warehouse row when it is missing. Subtractions
    require an existing row that holds at least the item's quantity.

    Raises:
        ValueError: on missing or insufficient stock for a subtraction.
    """
    product_id = item['product_id']
    variation_id = item['variation_id']
    batch_id = item['batch_id']
    quantity = item['quantity']
    adding = item['adjustment_type'] == 'addition'

    _shift_batch_remaining(cursor, batch_id, quantity if adding else -quantity)
    if batch_id:
        print(f"    ✅ Updated batch {batch_id}: {'+' if adding else '-'}{quantity}")

    stock = _lock_warehouse_stock_row(cursor, store_id, warehouse_id, product_id,
                                      variation_id, batch_id)

    if adding:
        if stock is None:
            _insert_warehouse_stock(cursor, store_id, warehouse_id, product_id,
                                    variation_id, batch_id, quantity)
            print(f"    ✅ Created new warehouse stock: {quantity}")
            return
        current_qty = safe_float(stock['quantity'])
        new_qty = current_qty + quantity
        _set_warehouse_quantity(cursor, stock['id'], new_qty)
        print(f"    ✅ Updated warehouse stock: {current_qty} + {quantity} = {new_qty}")
        return

    if stock is None:
        raise ValueError("Cannot subtract from non-existent warehouse stock")
    current_qty = safe_float(stock['quantity'])
    if current_qty < quantity:
        raise ValueError(
            f"Insufficient warehouse stock. "
            f"Available: {current_qty}, Requested: {quantity}"
        )
    new_qty = current_qty - quantity
    _set_warehouse_quantity(cursor, stock['id'], new_qty)
    print(f"    ✅ Updated warehouse stock: {current_qty} - {quantity} = {new_qty}")


@adjustment_bp.route('/update_adjustment/<int:adjustment_id>', methods=['PUT'])
@jwt_required()
@role_required('admin', 'cashier')
def update_adjustment(adjustment_id):
    """
    Update Stock Adjustment.

    Replaces an adjustment's items inside one transaction:
      1. Reverse the stored items' stock effects at the original store/warehouse.
      2. Delete the old item rows.
      3. Update the header.
      4. Insert and apply the new items at the requested store/warehouse.
    Any failure rolls the whole transaction back.

    Request JSON: ``warehouse_id`` and ``store_id`` (required), ``reason``,
    ``note``, and ``items``. ``items`` is a list of objects with
    ``product_id`` and ``quantity``, plus optional ``variation_id``,
    ``batch_id``, ``note`` and ``adjustment_type`` (``'addition'`` by
    default, or ``'subtraction'``). Items with a zero or negative quantity
    are skipped.

    Returns:
        200 with the adjustment code and processed-item count.
        400 for an invalid payload, or when stock would become negative or missing.
        404 if the adjustment does not exist.
        500 on database or server errors.
    """
    conn = None
    cursor = None

    try:
        data = request.get_json(force=True)
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid JSON data',
                        'message': 'Request body must be a JSON object'}), 400

    print("=" * 80)
    print(f"📝 UPDATING STOCK ADJUSTMENT #{adjustment_id}")
    print("=" * 80)
    print(json.dumps(data, indent=2, default=str))
    print("=" * 80)

    warehouse_id = safe_int(data.get('warehouse_id'))
    store_id = safe_int(data.get('store_id'))
    reason = data.get('reason', '')
    note = data.get('note', '')
    raw_items = data.get('items', [])
    # NOTE(review): the payload's 'adjustment_date' is accepted by clients but
    # never written to stock_adjustments here — confirm whether it should be.

    if not warehouse_id:
        return jsonify({'error': 'Warehouse is required'}), 400

    if not store_id:
        return jsonify({'error': 'Store is required'}), 400

    if not raw_items:
        return jsonify({'error': 'At least one item is required'}), 400

    if not isinstance(raw_items, list):
        return jsonify({'error': 'items must be a list'}), 400

    try:
        items = _normalize_update_items(raw_items)
    except ValueError as err:
        return jsonify({'error': str(err)}), 400

    if not items:
        # Committing would leave the adjustment with no item rows at all.
        return jsonify({'error': 'At least one item with a positive quantity is required'}), 400

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        conn.start_transaction()
        print("\n🔄 Transaction started")

        # Lock the header so a concurrent update/delete of the same adjustment
        # cannot reverse these items a second time.
        cursor.execute("""
            SELECT 
                adjustment_id,
                adjustment_code,
                warehouse_id as old_warehouse_id,
                store_id as old_store_id
            FROM stock_adjustments
            WHERE adjustment_id = %s
            FOR UPDATE
        """, (adjustment_id,))

        existing_adjustment = cursor.fetchone()

        if not existing_adjustment:
            _rollback_quietly(conn)
            return jsonify({'error': 'Adjustment not found'}), 404

        adjustment_code = existing_adjustment['adjustment_code']
        old_warehouse_id = existing_adjustment['old_warehouse_id']
        old_store_id = existing_adjustment['old_store_id']

        print(f"📝 Updating adjustment: {adjustment_code}")

        # STEP 1: Reverse old stock changes (at the location they were applied to)
        print("\n🔄 STEP 1: Reversing old stock changes")

        cursor.execute("""
            SELECT 
                product_id,
                variation_id,
                batch_id,
                quantity,
                adjustment_type
            FROM stock_adjustment_items
            WHERE adjustment_id = %s
        """, (adjustment_id,))

        old_items = cursor.fetchall()
        print(f"  Found {len(old_items)} old items to reverse")

        for idx, old_item in enumerate(old_items, 1):
            print(f"\n  📦 Reversing old item {idx}:")
            print(f"    Product ID: {old_item['product_id']}, "
                  f"Variation ID: {old_item['variation_id']}, Batch ID: {old_item['batch_id']}")
            print(f"    Quantity: {safe_float(old_item['quantity'])}, "
                  f"Type: {old_item['adjustment_type']}")
            _reverse_adjustment_item(cursor, old_store_id, old_warehouse_id, old_item)

        # STEP 2: Delete old adjustment items
        print("\n🗑️  STEP 2: Deleting old adjustment items")
        cursor.execute("""
            DELETE FROM stock_adjustment_items
            WHERE adjustment_id = %s
        """, (adjustment_id,))
        print(f"  ✅ Deleted {len(old_items)} old items")

        # STEP 3: Update adjustment header. A mixed set of item types is
        # recorded as 'addition'.
        print("\n📝 STEP 3: Updating adjustment header")

        item_types = {item['adjustment_type'] for item in items}
        adjustment_type = 'subtraction' if item_types == {'subtraction'} else 'addition'

        cursor.execute("""
            UPDATE stock_adjustments
            SET warehouse_id = %s,
                store_id = %s,
                adjustment_type = %s,
                reason = %s,
                note = %s,
                total_items = %s,
                updated_at = NOW()
            WHERE adjustment_id = %s
        """, (warehouse_id, store_id, adjustment_type, reason, note, len(items), adjustment_id))

        print("  ✅ Updated adjustment header")

        # STEP 4: Insert new adjustment items and update stock
        print("\n📝 STEP 4: Inserting new adjustment items")

        items_processed = 0

        for item in items:
            print(f"\n  📦 Processing new item {item['index']}:")
            print(f"    Product ID: {item['product_id']}, "
                  f"Variation ID: {item['variation_id']}, Batch ID: {item['batch_id']}")
            print(f"    Quantity: {item['quantity']}, Type: {item['adjustment_type']}")

            try:
                cursor.execute("""
                    INSERT INTO stock_adjustment_items (
                        adjustment_id, product_id, variation_id, batch_id,
                        quantity, adjustment_type, note, created_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
                """, (adjustment_id, item['product_id'], item['variation_id'],
                      item['batch_id'], item['quantity'], item['adjustment_type'],
                      item['note']))
                print(f"    ✅ Adjustment item created with ID: {cursor.lastrowid}")

                _apply_adjustment_item(cursor, store_id, warehouse_id, item)
            except Exception as item_error:
                print(f"  ❌ Error processing item {item['index']}: {item_error}")
                raise

            items_processed += 1

        conn.commit()
        print("\n✅ Transaction committed successfully!")
        print(f"✅ Items processed: {items_processed}/{len(raw_items)}")

        print("\n" + "=" * 80)
        print(f"✅ SUCCESS: Adjustment {adjustment_code} updated")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': 'Adjustment updated successfully',
            'adjustment_id': adjustment_id,
            'adjustment_code': adjustment_code,
            'items_processed': items_processed,
            'adjustment_type': adjustment_type
        }), 200

    except ValueError as err:
        _rollback_quietly(conn)
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.IntegrityError as err:
        _rollback_quietly(conn)
        print(f"\n❌ Database Integrity Error: {err}")
        return jsonify({'error': f'Database integrity error: {str(err)}'}), 500

    except mysql.connector.Error as err:
        _rollback_quietly(conn)
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        _rollback_quietly(conn)
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")


# ============================================
# ✅ DELETE STOCK ADJUSTMENT
# ============================================
@adjustment_bp.route('/delete_adjustment/<int:adjustment_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin', 'manager')
def delete_adjustment(adjustment_id):
    """
    Delete Stock Adjustment, reversing all of its stock changes first.

    Runs as one transaction:
      1. Lock the adjustment header.
      2. Undo every item's effect on ``product_batches.remaining_quantity``
         and on ``warehouse_stock``, at the header's store/warehouse.
      3. Delete the items, then the header.

    Reversing an addition removes stock, and fails if the warehouse row is
    missing or would go negative. Reversing a subtraction restores stock,
    recreating the warehouse row if it was removed. Items with any other
    adjustment_type never changed stock and are left alone.

    Returns:
        200 on success.
        400 when a reversal would leave negative or missing stock.
        404 if the adjustment does not exist.
        500 on database or server errors.
    Any failure rolls the transaction back.
    """
    conn = None
    cursor = None

    def _rollback():
        # conn is still None if get_db_connection() raised; a failing rollback
        # must not hide the error being reported.
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception as rollback_err:
            print(f"⚠️  Rollback failed: {rollback_err}")

    try:
        conn = get_db_connection()
        if not conn:
            return jsonify({'error': 'Database connection failed'}), 500

        cursor = conn.cursor(dictionary=True, buffered=True)

        conn.start_transaction()
        print(f"\n🔄 Transaction started for deleting adjustment {adjustment_id}")

        # FOR UPDATE makes a concurrent delete/update of this adjustment wait,
        # so its items cannot be reversed twice.
        cursor.execute("""
            SELECT 
                adjustment_id,
                adjustment_code,
                warehouse_id,
                store_id
            FROM stock_adjustments
            WHERE adjustment_id = %s
            FOR UPDATE
        """, (adjustment_id,))

        adjustment = cursor.fetchone()

        if not adjustment:
            _rollback()
            return jsonify({'error': 'Adjustment not found'}), 404

        print(f"📝 Deleting adjustment: {adjustment['adjustment_code']}")

        warehouse_id = adjustment['warehouse_id']
        store_id = adjustment['store_id']

        cursor.execute("""
            SELECT 
                product_id,
                variation_id,
                batch_id,
                quantity,
                adjustment_type
            FROM stock_adjustment_items
            WHERE adjustment_id = %s
        """, (adjustment_id,))

        items = cursor.fetchall()
        print(f"📦 Found {len(items)} items to reverse")

        for idx, item in enumerate(items, 1):
            product_id = item['product_id']
            variation_id = item['variation_id']
            batch_id = item['batch_id']
            quantity = safe_float(item['quantity'])
            adjustment_type = item['adjustment_type']

            print(f"\n  📦 Reversing item {idx}:")
            print(f"    Product ID: {product_id}, Batch ID: {batch_id}")
            print(f"    Quantity: {quantity}, Original Type: {adjustment_type}")

            if adjustment_type not in ('addition', 'subtraction'):
                print(f"    ⚠️  Unknown adjustment type {adjustment_type!r}; nothing to reverse")
                continue

            # Reversing an addition takes stock out; reversing a subtraction puts it back.
            removing = adjustment_type == 'addition'
            sign = '-' if removing else '+'
            delta = -quantity if removing else quantity

            if batch_id:
                cursor.execute("""
                    UPDATE product_batches
                    SET remaining_quantity = remaining_quantity + %s
                    WHERE batch_id = %s
                """, (delta, batch_id))
                print(f"    ✅ Reversed batch: {sign}{quantity}")

            # Lock the stock row for the read-modify-write below; <=> is
            # MySQL's NULL-safe equality for optional variation/batch ids.
            cursor.execute("""
                SELECT id, quantity FROM warehouse_stock
                WHERE store_id = %s AND warehouse_id = %s
                  AND product_id = %s
                  AND variation_id <=> %s
                  AND batch_id <=> %s
                FOR UPDATE
            """, (store_id, warehouse_id, product_id, variation_id, batch_id))

            existing_stock = cursor.fetchone()

            if existing_stock is None:
                # Previously skipped silently, leaving the batch already adjusted.
                if removing:
                    raise ValueError(
                        "Cannot reverse: warehouse stock record no longer exists"
                    )
                cursor.execute("""
                    INSERT INTO warehouse_stock (
                        store_id, warehouse_id, product_id, variation_id,
                        batch_id, quantity, updated_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, NOW())
                """, (store_id, warehouse_id, product_id, variation_id,
                      batch_id, quantity))
                print(f"    ✅ Recreated warehouse stock: {quantity}")
                continue

            current_qty = safe_float(existing_stock['quantity'])
            new_qty = current_qty + delta

            if removing and new_qty < 0:
                raise ValueError(
                    f"Cannot reverse: negative stock would result. "
                    f"Current: {current_qty}, To subtract: {quantity}"
                )

            cursor.execute("""
                UPDATE warehouse_stock
                SET quantity = %s, updated_at = NOW()
                WHERE id = %s
            """, (new_qty, existing_stock['id']))

            print(f"    ✅ Reversed warehouse stock: {current_qty} {sign} {quantity} = {new_qty}")

        cursor.execute("""
            DELETE FROM stock_adjustment_items
            WHERE adjustment_id = %s
        """, (adjustment_id,))

        print("✅ Deleted adjustment items")

        cursor.execute("""
            DELETE FROM stock_adjustments
            WHERE adjustment_id = %s
        """, (adjustment_id,))

        print("✅ Deleted adjustment record")

        conn.commit()
        print("\n✅ Transaction committed successfully!")

        print("\n" + "=" * 80)
        print(f"✅ SUCCESS: Adjustment {adjustment['adjustment_code']} deleted and reversed")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': 'Adjustment deleted successfully',
            'adjustment_id': adjustment_id,
            'adjustment_code': adjustment['adjustment_code']
        }), 200

    except ValueError as err:
        _rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.Error as err:
        _rollback()
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        _rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")


# ============================================
# ✅ GET BATCH STOCK INFO
# ============================================        
@adjustment_bp.route('/get_batch_stock/<int:batch_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'cashier')
def get_batch_stock(batch_id):
    """
    Report the current remaining quantity of one product batch.

    The batch is LEFT JOINed to its GRN and that GRN's supplier, so
    ``grn_code`` and ``supplier_name`` come back as ``None`` when either link
    is absent. NULL numeric columns are reported as ``0.0``.

    Returns:
        200 with the batch figures.
        404 when no batch has ``batch_id``.
        500 when the connection or query fails.
    """
    db = None
    cur = None

    try:
        db = get_db_connection()
        if not db:
            return jsonify({'error': 'Database connection failed'}), 500

        cur = db.cursor(dictionary=True, buffered=True)

        # Batch figures together with the GRN code and supplier name
        cur.execute("""
            SELECT 
                pb.batch_id,
                pb.batch_number,
                pb.remaining_quantity,
                pb.quantity as original_quantity,
                pb.cost,
                pb.price,
                pb.expiration_date,
                g.grn_code,
                s.supplier_name
            FROM product_batches pb
            LEFT JOIN grn g ON pb.grn_id = g.grn_id
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            WHERE pb.batch_id = %s
        """, (batch_id,))

        row = cur.fetchone()
        if not row:
            return jsonify({'error': 'Batch not found'}), 404

        def as_number(column):
            return float(row[column] or 0)

        expiry = row['expiration_date']
        payload = {
            'success': True,
            'batch_id': row['batch_id'],
            'batch_number': row['batch_number'],
            'remaining_quantity': as_number('remaining_quantity'),
            'original_quantity': as_number('original_quantity'),
            'cost': as_number('cost'),
            'price': as_number('price'),
            'expiration_date': expiry.strftime('%Y-%m-%d') if expiry else None,
            'grn_code': row['grn_code'],
            'supplier_name': row['supplier_name'],
        }
        return jsonify(payload), 200

    except Exception as exc:
        print(f"❌ Error fetching batch stock: {exc}")
        traceback.print_exc()
        return jsonify({'error': str(exc)}), 500

    finally:
        if cur:
            cur.close()
        if db:
            db.close()