from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity
from db.db import get_db_connection
from config.auth import role_required
import mysql.connector
import traceback
from datetime import datetime
import os
from werkzeug.utils import secure_filename
from flask import send_file
from datetime import datetime, date

# Flask blueprint for the GRN endpoints in this module; every route below is
# registered on it via @grn_bp.route(...).
grn_bp = Blueprint('grn', __name__)


def safe_float(value, default=0.0):
    """Convert ``value`` to a finite ``float``, falling back to ``default``.

    ``None``, ``''`` and the string ``'null'`` map to ``default`` without a
    conversion attempt. ``default`` is also returned when:

    - ``float()`` rejects the value (``ValueError`` / ``TypeError``);
    - the value is too large to represent (``OverflowError``, e.g. a huge
      JSON integer);
    - the result is NaN or +/-infinity (e.g. ``'nan'``, ``'inf'``, or the
      JSON ``NaN`` / ``Infinity`` literals).

    Rejecting non-finite results keeps them out of validation checks such
    as ``paid_amount <= 0``, which NaN would silently pass, and out of the
    financial totals.

    Args:
        value: Raw input, typically a JSON-decoded request field.
        default: Value returned when ``value`` is missing or unusable.

    Returns:
        float: The converted finite value, or ``default``.
    """
    import math

    try:
        # Membership test stays inside the try: exotic objects whose __eq__
        # raises are treated as unusable input, as before.
        if value in (None, '', 'null'):
            return default
        result = float(value)
    except (ValueError, TypeError, OverflowError):
        return default
    return result if math.isfinite(result) else default


def safe_int(value, default=0):
    """Convert ``value`` to ``int``, falling back to ``default``.

    ``None``, ``''`` and the string ``'null'`` map to ``default``.
    ``default`` is also returned when ``int()`` fails:

    - ``ValueError`` / ``TypeError``: e.g. ``'4.2'``, ``'abc'``, a list, NaN;
    - ``OverflowError``: +/-infinity, e.g. a JSON ``Infinity`` literal.

    Floats are truncated toward zero, as ``int()`` does.

    Args:
        value: Raw input, typically a JSON-decoded request field.
        default: Value returned when ``value`` is missing or unusable.

    Returns:
        int: The converted value, or ``default``.
    """
    try:
        if value in (None, '', 'null'):
            return default
        return int(value)
    except (ValueError, TypeError, OverflowError):
        return default


def generate_grn_code(cursor):
    """Generate the next GRN code for the current year: ``GRN-YYYY-NNNNN``.

    Finds the highest sequence number already used under this year's prefix
    and returns that number + 1. The number is zero-padded to at least five
    digits. The first code of a year is ``GRN-YYYY-00001``.

    Args:
        cursor: A DB cursor returning rows as dicts; rows are read as
            ``row['grn_code']``. ``create_grn`` opens it with
            ``dictionary=True``.

    Returns:
        str: The new code, e.g. ``GRN-2024-00042``.

    Note:
        The lookup is not locked, so two concurrent requests can compute the
        same code. Presumably a unique key on ``grn.grn_code`` turns the
        loser into an IntegrityError; TODO confirm against the schema.
    """
    year = datetime.now().year
    prefix = f"GRN-{year}-"

    # Order by the numeric suffix, not by grn_id. The most recently inserted
    # row is not guaranteed to carry the highest number (manual inserts,
    # data imports), and building on a lower number would re-issue a code
    # that already exists.
    cursor.execute("""
        SELECT grn_code FROM grn
        WHERE grn_code LIKE %s
        ORDER BY CAST(SUBSTRING_INDEX(grn_code, '-', -1) AS UNSIGNED) DESC,
                 grn_id DESC
        LIMIT 1
    """, (f"{prefix}%",))

    last_grn = cursor.fetchone()
    last_number = int(last_grn['grn_code'].split('-')[-1]) if last_grn else 0

    return f"{prefix}{last_number + 1:05d}"


def _clean_text(value):
    """Return ``value`` as a whitespace-stripped string ('' for ``None``).

    Request fields may be missing, an explicit JSON ``null``, or a non-string
    scalar such as a numeric invoice number. Calling ``.strip()`` directly on
    ``data.get(key, '')`` raises ``AttributeError`` for the last two cases.
    """
    if value is None:
        return ''
    return str(value).strip()


def _item_discount_rules(item):
    """Return an item's ``discount_rules`` as a list of dicts.

    A missing or ``null`` value yields ``[]``. A value that is not a list,
    and any rule entry that is not a JSON object, is ignored rather than
    crashing the request.
    """
    rules = item.get('discount_rules') or []
    if not isinstance(rules, list):
        return []
    return [rule for rule in rules if isinstance(rule, dict)]


def _ensure_grn_create_schema(cursor):
    """Ensure the columns and table that create_grn writes to exist.

    Adds ``grn_items.batch_price`` and ``grn_items.expiration_date`` when
    missing, and creates ``product_batch_discounts`` if absent. In that table
    ``grn_item_id`` (not ``batch_id``) is the link, because the product_batch
    is only created during GRN approval.

    These are DDL statements, and MySQL implicitly COMMITs the current
    transaction around DDL. This function must therefore run BEFORE the GRN
    transaction starts. Running it after the header INSERT commits the header
    early, so a later failure would leave an orphan GRN row.
    """
    print(f"\n📝 Schema check: Checking grn_items table structure...")

    cursor.execute("SHOW COLUMNS FROM grn_items LIKE 'batch_price'")
    if not cursor.fetchone():
        print(f"  ⚠️  batch_price column not found — adding...")
        cursor.execute("""
            ALTER TABLE grn_items
            ADD COLUMN batch_price DECIMAL(15,2) DEFAULT 0.00
            AFTER unit_price
        """)
        print(f"  ✅ batch_price column added")

    cursor.execute("SHOW COLUMNS FROM grn_items LIKE 'expiration_date'")
    if not cursor.fetchone():
        print(f"  ⚠️  expiration_date column not found — adding...")
        cursor.execute("""
            ALTER TABLE grn_items
            ADD COLUMN expiration_date DATE NULL
            AFTER note
        """)
        print(f"  ✅ expiration_date column added")

    print(f"\n📝 Schema check: Ensuring product_batch_discounts table exists...")

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS `product_batch_discounts` (
            `id`                INT(11)                NOT NULL AUTO_INCREMENT,
            `grn_item_id`       INT(11)                NOT NULL
                COMMENT 'FK to grn_items — links rule before batch is created',
            `batch_id`          INT(11)                DEFAULT NULL
                COMMENT 'FK to product_batches — populated during GRN approval',
            `payment_method_id` INT(11)                NOT NULL,
            `discount_rate`     DECIMAL(5,2)           NOT NULL DEFAULT 0.00,
            `discount_type`     ENUM('percent','fixed') NOT NULL DEFAULT 'percent',
            `is_active`         TINYINT(1)             NOT NULL DEFAULT 1,
            `created_at`        TIMESTAMP              NOT NULL DEFAULT CURRENT_TIMESTAMP,
            `updated_at`        TIMESTAMP              NOT NULL DEFAULT CURRENT_TIMESTAMP
                                ON UPDATE CURRENT_TIMESTAMP,
            PRIMARY KEY (`id`),
            UNIQUE KEY `ux_pbd_grn_item_pm` (`grn_item_id`, `payment_method_id`),
            KEY `idx_pbd_grn_item`       (`grn_item_id`),
            KEY `idx_pbd_batch`          (`batch_id`),
            KEY `idx_pbd_payment_method` (`payment_method_id`),
            KEY `idx_pbd_is_active`      (`is_active`),
            CONSTRAINT `fk_pbd_payment_method` FOREIGN KEY (`payment_method_id`)
                REFERENCES `payment_methods` (`id`)
                ON DELETE CASCADE ON UPDATE CASCADE
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci
          COMMENT='Per-GRN-item selling discount rules by payment method.
                   batch_id is NULL until GRN approval creates the product_batch.';
    """)

    print(f"  ✅ product_batch_discounts table ready")


def _resolve_receiver_id(cursor):
    """Return the users.id recorded as receiver/creator of the GRN.

    Uses the JWT identity when it matches an existing user. Otherwise falls
    back to the first user with role 'admin'.

    Raises:
        ValueError: if neither lookup finds a user.
    """
    current_user = get_jwt_identity()
    if current_user:
        cursor.execute("SELECT id FROM users WHERE id = %s", (current_user,))
        user = cursor.fetchone()
        if user:
            return user['id']

    cursor.execute("SELECT id FROM users WHERE role = 'admin' LIMIT 1")
    admin = cursor.fetchone()
    if admin:
        return admin['id']

    raise ValueError("No valid user found to assign as receiver")


def _fetch_valid_payment_method_ids(cursor, items):
    """Return the referenced payment method ids that exist and are active.

    Collects every ``payment_method_id`` from all items' discount rules. The
    ids are checked with a single ``IN (...)`` query instead of per-item round
    trips. Unknown or inactive ids are logged and later skipped; they are not
    fatal.
    """
    referenced_ids = set()
    for item in items:
        for rule in _item_discount_rules(item):
            pm_id = safe_int(rule.get('payment_method_id'))
            if pm_id:
                referenced_ids.add(pm_id)

    if not referenced_ids:
        return set()

    # Only '%s' placeholders are interpolated; the ids themselves are bound
    # as query parameters.
    placeholders = ','.join(['%s'] * len(referenced_ids))
    cursor.execute(
        f"SELECT id FROM payment_methods WHERE id IN ({placeholders}) AND is_active = 1",
        tuple(referenced_ids)
    )
    valid_pm_ids = {row['id'] for row in cursor.fetchall()}

    invalid_ids = referenced_ids - valid_pm_ids
    if invalid_ids:
        # Non-fatal: warn and skip invalid IDs rather than aborting the GRN
        print(f"  ⚠️  Invalid / inactive payment method IDs will be skipped: {invalid_ids}")

    return valid_pm_ids


def _build_grn_item(idx, item, valid_pm_ids):
    """Validate one request item and compute its grn_items values.

    Pricing rules:

    - Per-unit discount is ``unit_price * discount / 100`` when
      ``discount_type == 'percentage'``, otherwise ``discount`` as a fixed
      amount. The discounted price is floored at 0.
    - With ``tax_type == 'inclusive'`` the tax is extracted from the
      discounted price. Otherwise the tax is added on top of it.
    - Line totals are the per-unit values times the accepted quantity
      (received - rejected).

    Args:
        idx: 1-based position of the item, used in log and error messages.
        item: One entry (a dict) of the request's ``items`` list.
        valid_pm_ids: Active payment method ids. Discount rules that
            reference any other id are dropped.

    Returns:
        dict: Values for the grn_items insert, plus ``'discount_rules'``
        (sanitized rule dicts for product_batch_discounts).

    Raises:
        ValueError: if the product ID is missing, a quantity is negative,
            or more units are rejected than received.
    """
    product_id   = safe_int(item.get('product_id'))
    variation_id = safe_int(item.get('variation_id')) if item.get('variation_id') else None
    ordered_qty  = safe_float(item.get('ordered_quantity', 0))
    received_qty = safe_float(item.get('received_quantity', 0))
    rejected_qty = safe_float(item.get('rejected_quantity', 0))

    if not product_id:
        raise ValueError(f"Item {idx}: Product ID is required")
    # Guard against negative inputs cancelling out, e.g. -5 received with
    # -10 rejected would otherwise yield +5 accepted.
    if received_qty < 0 or rejected_qty < 0:
        raise ValueError(f"Item {idx}: Quantities cannot be negative")

    # Pricing
    unit_price    = safe_float(item.get('unit_price', 0))
    batch_price   = safe_float(item.get('batch_price', 0))
    discount      = safe_float(item.get('discount', 0))
    discount_type = item.get('discount_type', 'fixed')
    product_tax   = safe_float(item.get('product_tax', 0))
    tax_type      = item.get('tax_type', 'exclusive')
    net_unit_cost = safe_float(item.get('net_unit_cost', 0))

    quality_check = item.get('quality_check', 'pending')
    item_note     = _clean_text(item.get('note'))
    unit          = item.get('unit', '')

    # Parse expiration date (an invalid value is logged and ignored)
    expiration_date_str = item.get('expiration_date')
    expiration_date     = None
    if expiration_date_str:
        try:
            expiration_date = datetime.strptime(expiration_date_str, '%Y-%m-%d').date()
            print(f"     📅 Expiration date: {expiration_date}")
        except (ValueError, TypeError):
            print(f"     ⚠️  Invalid expiration date: {expiration_date_str}, skipping")

    # Keep only rules that reference a valid, active payment method
    discount_rules = []
    for rule in _item_discount_rules(item):
        pm_id = safe_int(rule.get('payment_method_id'))
        if pm_id and pm_id in valid_pm_ids:
            discount_rules.append({
                'payment_method_id': pm_id,
                'discount_type':     rule.get('discount_type', 'percent'),
                'discount_rate':     safe_float(rule.get('discount_rate', 0)),
                'is_active':         1 if rule.get('is_active') else 0,
            })

    # Accepted quantity
    accepted_qty = received_qty - rejected_qty
    if accepted_qty < 0:
        raise ValueError(f"Item {idx}: Accepted quantity cannot be negative")

    # Discount per unit
    if discount_type == 'percentage':
        discount_per_unit = (unit_price * discount) / 100
    else:
        discount_per_unit = discount

    price_after_discount = max(unit_price - discount_per_unit, 0)

    # Tax per unit
    if tax_type == 'inclusive':
        tax_per_unit      = (price_after_discount * product_tax) / (100 + product_tax)
        subtotal_per_unit = price_after_discount
    else:
        tax_per_unit      = (price_after_discount * product_tax) / 100
        subtotal_per_unit = price_after_discount + tax_per_unit

    total_discount_amount = discount_per_unit * accepted_qty
    total_tax_amount      = tax_per_unit * accepted_qty
    item_subtotal_calc    = accepted_qty * subtotal_per_unit

    print(f"\n  📦 Item {idx}: Product {product_id}")
    print(f"     Ordered: {ordered_qty}, Received: {received_qty}, Rejected: {rejected_qty}")
    print(f"     Accepted: {accepted_qty}")
    print(f"     Cost: {unit_price}, Price: {batch_price}, Net Cost: {net_unit_cost}")
    print(f"     Discount: {total_discount_amount:.2f}, Tax: {total_tax_amount:.2f}, Subtotal: {item_subtotal_calc:.2f}")
    print(f"     Discount Rules: {len(discount_rules)} rule(s)")
    if expiration_date:
        print(f"     📅 Expiration: {expiration_date}")

    return {
        'product_id':        product_id,
        'variation_id':      variation_id,
        'ordered_quantity':  ordered_qty,
        'received_quantity': received_qty,
        'rejected_quantity': rejected_qty,
        'unit_price':        unit_price,
        'batch_price':       batch_price,
        'discount':          discount,
        'discount_type':     discount_type,
        'product_tax':       product_tax,
        'tax_type':          tax_type,
        'net_unit_cost':     net_unit_cost,
        'discount_amount':   total_discount_amount,
        'tax_amount':        total_tax_amount,
        'subtotal':          item_subtotal_calc,
        'purchase_unit':     unit,
        'quality_check':     quality_check,
        'note':              item_note,
        'expiration_date':   expiration_date,
        'discount_rules':    discount_rules,
    }


def _insert_grn_item(cursor, grn_id, item_data):
    """Insert one grn_items row and its discount rules.

    Discount rules are stored with ``batch_id = NULL``; per the create_grn
    docstring, approval later sets the real batch_id. Rules are upserted on
    the (grn_item_id, payment_method_id) unique key.

    Returns:
        int: Number of discount rules written for this item.
    """
    cursor.execute("""
        INSERT INTO grn_items (
            grn_id, product_id, variation_id,
            ordered_quantity, received_quantity, rejected_quantity,
            unit_price, batch_price,
            discount_type, product_discount,
            tax_type, product_tax,
            discount, tax, net_unit_cost, subtotal,
            purchase_unit, quality_check, note,
            expiration_date, created_at
        ) VALUES (
            %s, %s, %s,
            %s, %s, %s,
            %s, %s,
            %s, %s,
            %s, %s,
            %s, %s, %s, %s,
            %s, %s, %s,
            %s, NOW()
        )
    """, (
        grn_id,
        item_data['product_id'],
        item_data['variation_id'],
        item_data['ordered_quantity'],
        item_data['received_quantity'],
        item_data['rejected_quantity'],
        item_data['unit_price'],
        item_data['batch_price'],
        item_data['discount_type'],
        item_data['discount'],
        item_data['tax_type'],
        item_data['product_tax'],
        item_data['discount_amount'],
        item_data['tax_amount'],
        item_data['net_unit_cost'],
        item_data['subtotal'],
        item_data['purchase_unit'],
        item_data['quality_check'],
        item_data['note'],
        item_data['expiration_date'],
    ))

    grn_item_id = cursor.lastrowid
    exp_info    = f" | Exp: {item_data['expiration_date']}" if item_data['expiration_date'] else ""
    print(f"  ✅ GRN item {grn_item_id}: Product {item_data['product_id']}{exp_info}")

    rules = item_data.get('discount_rules', [])
    if not rules:
        print(f"     ℹ️  No discount rules for grn_item {grn_item_id}")
        return 0

    print(f"     💾 Saving {len(rules)} discount rule(s) for grn_item {grn_item_id}...")
    for rule in rules:
        cursor.execute("""
            INSERT INTO product_batch_discounts
                (grn_item_id, batch_id, payment_method_id,
                 discount_rate, discount_type, is_active)
            VALUES
                (%s, NULL, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                discount_rate  = VALUES(discount_rate),
                discount_type  = VALUES(discount_type),
                is_active      = VALUES(is_active),
                updated_at     = CURRENT_TIMESTAMP
        """, (
            grn_item_id,
            rule['payment_method_id'],
            rule['discount_rate'],
            rule['discount_type'],
            rule['is_active'],
        ))
        print(f"       ✅ PM {rule['payment_method_id']}: "
              f"{rule['discount_rate']} {rule['discount_type']} "
              f"({'active' if rule['is_active'] else 'inactive'})")

    return len(rules)


def _refresh_po_grn_status(cursor, purchase_order_id):
    """Recompute and store ``purchase_orders.grn_status`` for one PO.

    Compares the total ordered quantity (order_items) with the total accepted
    quantity (received - rejected) across all non-cancelled GRNs of the PO.
    The result is 'not_received' when nothing has been accepted, 'partial'
    when accepted < ordered, and 'fully_received' otherwise.

    Returns:
        str: The status that was written.
    """
    cursor.execute("""
        SELECT
            SUM(oi.quantity) as total_ordered,
            COALESCE((
                SELECT SUM(gi.received_quantity - gi.rejected_quantity)
                FROM grn g2
                INNER JOIN grn_items gi ON g2.grn_id = gi.grn_id
                WHERE g2.purchase_order_id = %s
                AND g2.status != 'cancelled'
            ), 0) as total_received
        FROM order_items oi
        WHERE oi.order_id = %s
    """, (purchase_order_id, purchase_order_id))

    qty_result     = cursor.fetchone()
    total_ordered  = safe_float(qty_result['total_ordered'])
    total_received = safe_float(qty_result['total_received'])

    if total_received == 0:
        new_grn_status = 'not_received'
    elif total_received < total_ordered:
        new_grn_status = 'partial'
    else:
        new_grn_status = 'fully_received'

    cursor.execute("""
        UPDATE purchase_orders
        SET grn_status = %s
        WHERE order_id = %s
    """, (new_grn_status, purchase_order_id))

    return new_grn_status


@grn_bp.route('/create_grn', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def create_grn():
    """
    CREATE GRN - COMPLETE VERSION WITH BATCH DISCOUNT RULES

    This endpoint creates GRN records with:
    - Warehouse and supplier from REQUEST (changeable from PO)
    - Item-level pricing (cost, price, discount, tax)
    - Order-level tax percentage
    - Order-level discount (fixed amount)
    - Payment information (status, type, amounts)
    - Expiration dates for items
    - Complete financial tracking
    - Per-item selling discount rules per payment method

    Stock is NOT updated here - that happens during approval.

    SCHEMA PRE-CHECK (runs BEFORE the transaction):
    - Adds missing grn_items columns and creates product_batch_discounts.
      This is DDL, which MySQL implicitly commits. It therefore runs before
      start_transaction() so the GRN inserts stay atomic.

    SAVES TO (single transaction):
    - grn                       (header with order tax, discount, payment)
    - grn_items                 (with ALL pricing details + expiration_date)
    - product_batch_discounts   (selling discount rules per payment method)
                                 NOTE: batch_id is NULL at this stage —
                                 updated to real batch_id during GRN approval
    - Updates purchase_orders.grn_status
    - Creates activity log

    DOES NOT TOUCH:
    - product_batches  (created during approval)
    - warehouse_stock  (updated during approval)

    Responses:
    - 201 on success
    - 400 on invalid input or business-rule violations
    - 500 on database or server errors
    """

    try:
        data = request.get_json(force=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400

    # Everything below reads fields with data.get(); a JSON array or scalar
    # body would otherwise crash with AttributeError.
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    print("=" * 80)
    print("📋 GRN CREATION")
    print("=" * 80)

    # ── Extract basic data ──────────────────────────────────────────────────────
    # _clean_text tolerates explicit nulls / non-string values.
    purchase_order_id  = safe_int(data.get('purchase_order_id'))
    grn_date_str       = data.get('grn_date')
    received_by_name   = _clean_text(data.get('received_by'))
    invoice_number     = _clean_text(data.get('invoice_number')) or None
    invoice_date_str   = data.get('invoice_date') or None
    vehicle_number     = _clean_text(data.get('vehicle_number')) or None
    driver_name        = _clean_text(data.get('driver_name')) or None
    driver_contact     = _clean_text(data.get('driver_contact')) or None
    note               = _clean_text(data.get('note'))
    items              = data.get('items', [])

    # ── Get warehouse and supplier from REQUEST (not from PO) ───────────────────
    warehouse_id = safe_int(data.get('warehouse_id'))
    supplier_id  = safe_int(data.get('supplier_id'))

    # ── Order-level tax, discount, payment ─────────────────────────────────────
    order_tax_percentage  = safe_float(data.get('order_tax', 0))   # Percentage
    order_discount_amount = safe_float(data.get('discount', 0))    # Fixed amount
    payment_status        = data.get('payment_status', 'Unpaid')   # Paid/Unpaid/Partial

    payment_type_raw = data.get('payment_type')
    if isinstance(payment_type_raw, str):
        payment_type = payment_type_raw.strip() or None
    else:
        payment_type = None

    paid_amount = safe_float(data.get('paid_amount', 0))

    print(f"\n📊 Order-Level Financials:")
    print(f"   Order Tax        : {order_tax_percentage}%")
    print(f"   Order Discount   : LKR {order_discount_amount}")
    print(f"   Payment Status   : {payment_status}")
    print(f"   Payment Type     : {payment_type or 'N/A (Unpaid)'}")
    print(f"   Paid Amount      : LKR {paid_amount}")

    # ── Validation ──────────────────────────────────────────────────────────────
    if not purchase_order_id:
        return jsonify({'error': 'Purchase order ID is required'}), 400

    if not grn_date_str:
        return jsonify({'error': 'GRN date is required'}), 400

    if not received_by_name:
        return jsonify({'error': 'Received by is required'}), 400

    if not warehouse_id:
        return jsonify({'error': 'Warehouse is required'}), 400

    if not supplier_id:
        return jsonify({'error': 'Supplier is required'}), 400

    if not items:
        return jsonify({'error': 'At least one item is required'}), 400

    if not isinstance(items, list) or not all(isinstance(item, dict) for item in items):
        return jsonify({'error': 'Items must be a list of objects'}), 400

    if payment_status not in ['Paid', 'Unpaid', 'Partial']:
        return jsonify({'error': 'Invalid payment status. Must be Paid, Unpaid, or Partial'}), 400

    if payment_status in ['Paid', 'Partial'] and not payment_type:
        return jsonify({'error': 'Payment type is required for paid/partial payments'}), 400

    if payment_status == 'Partial' and paid_amount <= 0:
        return jsonify({'error': 'Paid amount must be greater than 0 for partial payments'}), 400

    # ── Parse dates ─────────────────────────────────────────────────────────────
    # TypeError covers non-string values (e.g. a number sent as the date).
    try:
        grn_date = datetime.strptime(grn_date_str, '%Y-%m-%d').date()
    except (ValueError, TypeError):
        return jsonify({'error': 'Invalid GRN date format. Use YYYY-MM-DD'}), 400

    invoice_date = None
    if invoice_date_str:
        try:
            invoice_date = datetime.strptime(invoice_date_str, '%Y-%m-%d').date()
        except (ValueError, TypeError):
            return jsonify({'error': 'Invalid invoice date format. Use YYYY-MM-DD'}), 400

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    # Created inside the try so the finally block still closes the
    # connection if cursor creation itself fails.
    cursor = None
    try:
        cursor = conn.cursor(dictionary=True)

        # ── Schema pre-check (DDL) — must precede the transaction ────────────────
        _ensure_grn_create_schema(cursor)
        # DDL already commits implicitly. This explicit commit just guarantees
        # no implicit transaction is left open, which start_transaction()
        # would reject.
        conn.commit()

        conn.start_transaction()
        print("\n🔄 Transaction started")

        # ── Step 1: Validate purchase order ─────────────────────────────────────
        print(f"\n📝 Step 1: Validating purchase order {purchase_order_id}...")
        cursor.execute("""
            SELECT
                order_id, supplier_id, warehouse_id, store_id,
                status, grn_status, grand_total
            FROM purchase_orders
            WHERE order_id = %s
        """, (purchase_order_id,))

        po = cursor.fetchone()
        if not po:
            raise ValueError(f"Purchase order {purchase_order_id} not found")

        if po['grn_status'] == 'fully_received':
            raise ValueError(f"Purchase order {purchase_order_id} is already fully received")

        store_id = po['store_id']

        print(f"  ✅ PO validated")
        print(f"  📦 Warehouse ID from REQUEST : {warehouse_id} (PO had: {po['warehouse_id']})")
        print(f"  🚚 Supplier ID from REQUEST  : {supplier_id} (PO had: {po['supplier_id']})")
        print(f"  🏪 Store ID from PO          : {store_id}")

        # Validate warehouse
        cursor.execute("SELECT id FROM warehouses WHERE id = %s AND is_active = 1", (warehouse_id,))
        if not cursor.fetchone():
            raise ValueError(f"Warehouse {warehouse_id} not found or inactive")

        # Validate supplier
        cursor.execute("SELECT id FROM suppliers WHERE id = %s", (supplier_id,))
        if not cursor.fetchone():
            raise ValueError(f"Supplier {supplier_id} not found")

        # ── Step 2: Get user ID ──────────────────────────────────────────────────
        print(f"\n📝 Step 2: Getting user ID for '{received_by_name}'...")
        received_by_id = _resolve_receiver_id(cursor)
        print(f"  ✅ Receiver user ID: {received_by_id}")

        # ── Step 3: Generate GRN code ────────────────────────────────────────────
        print(f"\n📝 Step 3: Generating GRN code...")
        grn_code = generate_grn_code(cursor)
        print(f"  ✅ GRN Code: {grn_code}")

        # ── Step 4: Validate payment method IDs in discount_rules ───────────────
        print(f"\n📝 Step 4: Validating payment method IDs in discount rules...")
        valid_pm_ids = _fetch_valid_payment_method_ids(cursor, items)
        print(f"  ✅ Valid payment method IDs: {valid_pm_ids}")

        # ── Step 5: Process items ────────────────────────────────────────────────
        print(f"\n📝 Step 5: Processing {len(items)} items...")

        total_items    = len(items)
        item_subtotal  = 0.0
        grn_items_data = []

        for idx, item in enumerate(items, 1):
            item_data = _build_grn_item(idx, item, valid_pm_ids)
            item_subtotal += item_data['subtotal']
            grn_items_data.append(item_data)

        # ── Step 6: Calculate order-level totals ─────────────────────────────────
        print(f"\n📝 Step 6: Calculating order-level totals...")

        subtotal                = item_subtotal
        subtotal_after_discount = max(subtotal - order_discount_amount, 0)
        order_tax_amount        = (subtotal_after_discount * order_tax_percentage) / 100
        grand_total             = subtotal_after_discount + order_tax_amount

        if payment_status == 'Paid':
            actual_paid = grand_total
            due_amount  = 0
        elif payment_status == 'Partial':
            actual_paid = paid_amount
            due_amount  = max(grand_total - paid_amount, 0)
        else:  # Unpaid
            actual_paid = 0
            due_amount  = grand_total

        print(f"\n  💰 Financial Summary:")
        print(f"     Item Subtotal           : LKR {item_subtotal:.2f}")
        print(f"     Order Discount          : -LKR {order_discount_amount:.2f}")
        print(f"     Subtotal after discount : LKR {subtotal_after_discount:.2f}")
        print(f"     Order Tax ({order_tax_percentage}%)      : +LKR {order_tax_amount:.2f}")
        print(f"     Grand Total             : LKR {grand_total:.2f}")
        print(f"     Payment Status          : {payment_status}")
        print(f"     Payment Type            : {payment_type or 'N/A (Unpaid)'}")
        print(f"     Paid Amount             : LKR {actual_paid:.2f}")
        print(f"     Due Amount              : LKR {due_amount:.2f}")

        # ── Step 7: Insert GRN header ────────────────────────────────────────────
        print(f"\n📝 Step 7: Creating GRN header...")

        cursor.execute("""
            INSERT INTO grn (
                grn_code, purchase_order_id, supplier_id, warehouse_id, store_id,
                grn_date, received_by, invoice_number, invoice_date,
                vehicle_number, driver_name, driver_contact,
                status, total_items, subtotal, tax, discount, grand_total,
                order_tax, payment_status, payment_type, paid_amount, due_amount,
                note, created_by, created_at
            ) VALUES (
                %s, %s, %s, %s, %s,
                %s, %s, %s, %s,
                %s, %s, %s,
                'pending', %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s,
                %s, %s, NOW()
            )
        """, (
            grn_code, purchase_order_id, supplier_id, warehouse_id, store_id,
            grn_date, received_by_id, invoice_number, invoice_date,
            vehicle_number, driver_name, driver_contact,
            total_items, subtotal, order_tax_amount, order_discount_amount, grand_total,
            order_tax_percentage, payment_status, payment_type, actual_paid, due_amount,
            note, received_by_id
        ))

        grn_id = cursor.lastrowid
        print(f"  ✅ GRN created with ID: {grn_id}")

        # ── Step 10: Insert GRN items + discount rules ───────────────────────────
        # (Former steps 8-9, the schema DDL, now run before the transaction.)
        print(f"\n📝 Step 10: Inserting GRN items and discount rules...")

        total_discount_rules_saved = 0
        for item_data in grn_items_data:
            total_discount_rules_saved += _insert_grn_item(cursor, grn_id, item_data)

        # ── Step 11: Update purchase order GRN status ────────────────────────────
        print(f"\n📝 Step 11: Updating PO GRN status...")
        new_grn_status = _refresh_po_grn_status(cursor, purchase_order_id)
        print(f"  ✅ PO GRN status: {new_grn_status}")

        # ── Step 12: Activity log ────────────────────────────────────────────────
        print(f"\n📝 Step 12: Creating activity log...")

        cursor.execute("""
            INSERT INTO grn_activity_log (
                grn_id, user_id, action, description, created_at
            ) VALUES (
                %s, %s, 'GRN Created', %s, NOW()
            )
        """, (
            grn_id, received_by_id,
            (
                f"GRN {grn_code} created for PO {purchase_order_id}. "
                f"Warehouse: {warehouse_id}, Supplier: {supplier_id}, "
                f"Payment: {payment_status}, Grand Total: LKR {grand_total:.2f}, "
                f"Discount rules saved: {total_discount_rules_saved}"
            )
        ))

        print(f"  ✅ Activity logged")

        # ── Commit ───────────────────────────────────────────────────────────────
        conn.commit()
        print(f"\n✅ Transaction committed successfully!")

        print("\n" + "=" * 80)
        print("✅ GRN CREATION COMPLETE")
        print("=" * 80)
        print(f"✅ GRN Code            : {grn_code}")
        print(f"✅ GRN ID              : {grn_id}")
        print(f"✅ Warehouse           : {warehouse_id} (PO had: {po['warehouse_id']})")
        print(f"✅ Supplier            : {supplier_id} (PO had: {po['supplier_id']})")
        print(f"✅ Status              : pending")
        print(f"✅ Total Items         : {total_items}")
        print(f"✅ Items w/ expiry     : {sum(1 for i in grn_items_data if i['expiration_date'])}")
        print(f"✅ Discount rules saved: {total_discount_rules_saved}")
        print(f"✅ Grand Total         : LKR {grand_total:.2f}")
        print(f"✅ Payment Status      : {payment_status}")
        print(f"✅ Payment Type        : {payment_type or 'N/A (Unpaid)'}")
        print(f"✅ Due Amount          : LKR {due_amount:.2f}")
        print(f"\n⚠️  IMPORTANT:")
        print(f"   📦 product_batches          : NOT created yet")
        print(f"   📦 warehouse_stock          : NOT updated yet")
        print(f"   🔗 product_batch_discounts  : batch_id = NULL (set during approval)")
        print(f"   🔄 Stock + batch_id updated when GRN is APPROVED")
        print("=" * 80)

        return jsonify({
            'success':  True,
            'message':  f'GRN {grn_code} created successfully. Awaiting approval.',
            'grn_id':   grn_id,
            'grn_code': grn_code,
            'data': {
                'grn_id':                   grn_id,
                'grn_code':                 grn_code,
                'purchase_order_id':        purchase_order_id,
                'warehouse_id':             warehouse_id,
                'supplier_id':              supplier_id,
                'status':                   'pending',
                'total_items':              total_items,
                'subtotal':                 subtotal,
                'order_tax':                order_tax_amount,
                'order_tax_percentage':     order_tax_percentage,
                'discount':                 order_discount_amount,
                'grand_total':              grand_total,
                'payment_status':           payment_status,
                'payment_type':             payment_type,
                'paid_amount':              actual_paid,
                'due_amount':               due_amount,
                'grn_status':               new_grn_status,
                'discount_rules_saved':     total_discount_rules_saved,
                'stock_updated':            False,
                'requires_approval':        True,
            }
        }), 201

    except mysql.connector.IntegrityError as err:
        conn.rollback()
        print(f"\n❌ Database Integrity Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database integrity error: {str(err)}'}), 500

    except ValueError as err:
        conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")
        
        
@grn_bp.route('/update_grn/<int:grn_id>', methods=['PUT'])
@jwt_required()
@role_required('admin', 'manager')
def update_grn(grn_id):
    """
    UPDATE GRN - COMPLETE VERSION WITH BATCH DISCOUNT RULES

    Features:
    - Updates warehouse and supplier
    - Updates all pricing details (unit_price, batch_price)
    - Updates payment information
    - Updates expiration dates
    - Supports editing PENDING and COMPLETED GRNs
    - Reverses and recreates stock for COMPLETED GRNs
    - Saves per-item selling discount rules to product_batch_discounts
    - Automatic column creation if missing

    RESTRICTIONS:
    - Only PENDING and COMPLETED GRNs can be edited
    - CANCELLED GRNs cannot be edited

    DISCOUNT RULES FLOW:
    - PENDING GRN  : delete old rules by grn_item_id → insert new rules (batch_id = NULL)
    - COMPLETED GRN: delete old rules by batch_id   → insert new rules with real batch_id
    """

    try:
        data = request.get_json(force=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
    except Exception as e:
        return jsonify({'error': 'Invalid JSON data', 'message': str(e)}), 400

    print("=" * 80)
    print(f"✏️ GRN UPDATE - GRN ID: {grn_id}")
    print("=" * 80)

    # ── Extract basic data ──────────────────────────────────────────────────────
    grn_date_str     = data.get('grn_date')
    received_by_name = data.get('received_by', '').strip()
    invoice_number   = data.get('invoice_number', '').strip() or None
    invoice_date_str = data.get('invoice_date') or None
    vehicle_number   = data.get('vehicle_number', '').strip() or None
    driver_name      = data.get('driver_name', '').strip() or None
    driver_contact   = data.get('driver_contact', '').strip() or None
    note             = data.get('note', '').strip()
    items            = data.get('items', [])

    # ── Get warehouse and supplier from request ─────────────────────────────────
    warehouse_id = safe_int(data.get('warehouse_id'))
    supplier_id  = safe_int(data.get('supplier_id'))

    # ── Order-level tax, discount, payment ─────────────────────────────────────
    order_tax_percentage  = safe_float(data.get('order_tax', 0))
    order_discount_amount = safe_float(data.get('discount', 0))
    payment_status        = data.get('payment_status', 'Unpaid')

    payment_type_raw = data.get('payment_type')
    if payment_type_raw is not None and isinstance(payment_type_raw, str):
        payment_type = payment_type_raw.strip() or None
    else:
        payment_type = None

    paid_amount = safe_float(data.get('paid_amount', 0))

    print(f"\n📊 Update Data:")
    print(f"   Warehouse ID  : {warehouse_id}")
    print(f"   Supplier ID   : {supplier_id}")
    print(f"   Order Tax     : {order_tax_percentage}%")
    print(f"   Order Discount: LKR {order_discount_amount}")
    print(f"   Payment Status: {payment_status}")
    print(f"   Payment Type  : {payment_type or 'N/A (Unpaid)'}")
    print(f"   Items Count   : {len(items)}")

    # ── Validation ──────────────────────────────────────────────────────────────
    if not grn_date_str:
        return jsonify({'error': 'GRN date is required'}), 400

    if not received_by_name:
        return jsonify({'error': 'Received by is required'}), 400

    if not warehouse_id:
        return jsonify({'error': 'Warehouse is required'}), 400

    if not supplier_id:
        return jsonify({'error': 'Supplier is required'}), 400

    if not items or len(items) == 0:
        return jsonify({'error': 'At least one item is required'}), 400

    if payment_status not in ['Paid', 'Unpaid', 'Partial']:
        return jsonify({'error': 'Invalid payment status. Must be Paid, Unpaid, or Partial'}), 400

    if payment_status in ['Paid', 'Partial'] and not payment_type:
        return jsonify({'error': 'Payment type is required for paid/partial payments'}), 400

    if payment_status == 'Partial' and paid_amount <= 0:
        return jsonify({'error': 'Paid amount must be greater than 0 for partial payments'}), 400

    # ── Parse dates ─────────────────────────────────────────────────────────────
    try:
        grn_date = datetime.strptime(grn_date_str, '%Y-%m-%d').date()
    except ValueError:
        return jsonify({'error': 'Invalid GRN date format. Use YYYY-MM-DD'}), 400

    invoice_date = None
    if invoice_date_str:
        try:
            invoice_date = datetime.strptime(invoice_date_str, '%Y-%m-%d').date()
        except ValueError:
            return jsonify({'error': 'Invalid invoice date format. Use YYYY-MM-DD'}), 400

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        conn.start_transaction()
        print("\n🔄 Transaction started")

        # ── Step 1: Get existing GRN ─────────────────────────────────────────────
        print(f"\n📝 Step 1: Fetching existing GRN {grn_id}...")
        cursor.execute("SELECT * FROM grn WHERE grn_id = %s", (grn_id,))
        existing_grn = cursor.fetchone()

        if not existing_grn:
            raise ValueError(f"GRN {grn_id} not found")

        if existing_grn['status'] == 'cancelled':
            raise ValueError("Cannot update cancelled GRN. Only pending and completed GRNs can be edited.")

        is_completed      = existing_grn['status'] == 'completed'
        purchase_order_id = existing_grn['purchase_order_id']
        old_warehouse_id  = existing_grn['warehouse_id']
        old_supplier_id   = existing_grn['supplier_id']
        store_id          = existing_grn['store_id']
        grn_code          = existing_grn['grn_code']

        if is_completed:
            print(f"  ⚠️  WARNING: Editing COMPLETED GRN {grn_code} — stock will be REVERSED and RECREATED!")
        else:
            print(f"  ✅ GRN {grn_code} found (status: {existing_grn['status']})")

        print(f"  📦 Old Warehouse: {old_warehouse_id}, New: {warehouse_id}")
        print(f"  🚚 Old Supplier : {old_supplier_id}, New: {supplier_id}")

        # Validate warehouse
        cursor.execute("SELECT id FROM warehouses WHERE id = %s AND is_active = 1", (warehouse_id,))
        if not cursor.fetchone():
            raise ValueError(f"Warehouse {warehouse_id} not found or inactive")

        # Validate supplier
        cursor.execute("SELECT id FROM suppliers WHERE id = %s", (supplier_id,))
        if not cursor.fetchone():
            raise ValueError(f"Supplier {supplier_id} not found")

        # ── Step 2: Get user ID ──────────────────────────────────────────────────
        print(f"\n📝 Step 2: Getting user ID...")

        current_user   = get_jwt_identity()
        received_by_id = None

        if current_user:
            cursor.execute("SELECT id FROM users WHERE id = %s", (current_user,))
            user = cursor.fetchone()
            if user:
                received_by_id = user['id']

        if not received_by_id:
            cursor.execute("SELECT id FROM users WHERE role = 'admin' LIMIT 1")
            admin = cursor.fetchone()
            if admin:
                received_by_id = admin['id']

        if not received_by_id:
            raise ValueError("No valid user found")

        print(f"  ✅ User ID: {received_by_id}")

        # ── Step 3: Validate payment method IDs in discount_rules ───────────────
        # Collect all unique payment_method_ids from all items up front.
        print(f"\n📝 Step 3: Validating payment method IDs in discount rules...")

        all_pm_ids = set()
        for item in items:
            for rule in item.get('discount_rules', []):
                pm_id = safe_int(rule.get('payment_method_id'))
                if pm_id:
                    all_pm_ids.add(pm_id)

        valid_pm_ids = set()
        if all_pm_ids:
            fmt = ','.join(['%s'] * len(all_pm_ids))
            cursor.execute(
                f"SELECT id FROM payment_methods WHERE id IN ({fmt}) AND is_active = 1",
                tuple(all_pm_ids)
            )
            valid_pm_ids = {row['id'] for row in cursor.fetchall()}

            invalid_ids = all_pm_ids - valid_pm_ids
            if invalid_ids:
                print(f"  ⚠️  Invalid / inactive payment method IDs will be skipped: {invalid_ids}")

        print(f"  ✅ Valid payment method IDs: {valid_pm_ids}")

        # ── Step 4: Check product_batch_discounts table ──────────────────────────
        print(f"\n📝 Step 4: Checking product_batch_discounts table...")

        pbd_table_exists    = False
        pbd_has_grn_item_id = False

        cursor.execute("""
            SELECT COUNT(*) as cnt FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'product_batch_discounts'
        """)
        pbd_table_exists = cursor.fetchone()['cnt'] > 0

        if pbd_table_exists:
            cursor.execute("""
                SELECT COUNT(*) as cnt FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = DATABASE()
                  AND TABLE_NAME   = 'product_batch_discounts'
                  AND COLUMN_NAME  = 'grn_item_id'
            """)
            pbd_has_grn_item_id = cursor.fetchone()['cnt'] > 0

        print(f"  {'✅' if pbd_table_exists else '⚠️ '} product_batch_discounts: "
              f"{'EXISTS' if pbd_table_exists else 'NOT FOUND'}")
        print(f"  {'✅' if pbd_has_grn_item_id else '⚠️ '} grn_item_id column: "
              f"{'EXISTS' if pbd_has_grn_item_id else 'NOT FOUND — discount rules will be skipped'}")

        # ── Step 5: Reverse stock if COMPLETED ───────────────────────────────────
        if is_completed:
            print(f"\n📝 Step 5: REVERSING stock for completed GRN...")

            cursor.execute("SELECT * FROM grn_items WHERE grn_id = %s", (grn_id,))
            old_items = cursor.fetchall()

            for idx, old_item in enumerate(old_items, 1):
                if not old_item.get('batch_id'):
                    print(f"  ⚠️  Item {idx}: No batch_id, skipping stock reversal")
                    continue

                batch_id     = old_item['batch_id']
                accepted_qty = old_item['received_quantity'] - old_item['rejected_quantity']

                print(f"\n  🔄 Reversing Item {idx}:")
                print(f"     Product  : {old_item['product_id']}")
                print(f"     Batch ID : {batch_id}")
                print(f"     Qty      : {accepted_qty}")

                # Delete warehouse stock for this batch
                cursor.execute("""
                    DELETE FROM warehouse_stock
                    WHERE batch_id = %s AND store_id = %s AND warehouse_id = %s
                """, (batch_id, store_id, old_warehouse_id))
                print(f"     ✅ Deleted {cursor.rowcount} warehouse_stock record(s)")

                # Delete discount rules tied to this batch
                if pbd_table_exists:
                    cursor.execute("""
                        DELETE FROM product_batch_discounts WHERE batch_id = %s
                    """, (batch_id,))
                    print(f"     ✅ Deleted {cursor.rowcount} discount rule(s) for batch {batch_id}")

                # Delete the batch itself
                cursor.execute("DELETE FROM product_batches WHERE batch_id = %s", (batch_id,))
                print(f"     ✅ Deleted {cursor.rowcount} product_batch record(s)")

            print(f"\n  ✅ Stock reversal complete for {len(old_items)} items")

        # ── Step 6: Process new items ────────────────────────────────────────────
        print(f"\n📝 Step 6: Processing {len(items)} new items...")

        total_items         = len(items)
        item_subtotal       = 0.0
        total_item_tax      = 0.0
        total_item_discount = 0.0
        grn_items_data      = []

        for idx, item in enumerate(items, 1):
            product_id   = safe_int(item.get('product_id'))
            variation_id = safe_int(item.get('variation_id')) if item.get('variation_id') else None
            ordered_qty  = safe_float(item.get('ordered_quantity', 0))
            received_qty = safe_float(item.get('received_quantity', 0))
            rejected_qty = safe_float(item.get('rejected_quantity', 0))

            unit_price    = safe_float(item.get('unit_price', 0))
            batch_price   = safe_float(item.get('batch_price', 0))
            discount      = safe_float(item.get('discount', 0))
            discount_type = item.get('discount_type', 'fixed')
            product_tax   = safe_float(item.get('product_tax', 0))
            tax_type      = item.get('tax_type', 'exclusive')
            net_unit_cost = safe_float(item.get('net_unit_cost', 0))

            quality_check = item.get('quality_check', 'pending')
            item_note     = item.get('note', '').strip()
            unit          = item.get('unit', '')

            # Parse expiration date
            expiration_date_str = item.get('expiration_date')
            expiration_date     = None
            if expiration_date_str:
                try:
                    expiration_date = datetime.strptime(expiration_date_str, '%Y-%m-%d').date()
                except (ValueError, TypeError):
                    print(f"     ⚠️  Invalid expiration date: {expiration_date_str}, skipping")

            # Filter discount rules
            raw_rules      = item.get('discount_rules', [])
            discount_rules = []
            for rule in raw_rules:
                pm_id = safe_int(rule.get('payment_method_id'))
                if pm_id and pm_id in valid_pm_ids:
                    discount_rules.append({
                        'payment_method_id': pm_id,
                        'discount_type':     rule.get('discount_type', 'percent'),
                        'discount_rate':     safe_float(rule.get('discount_rate', 0)),
                        'is_active':         1 if rule.get('is_active') else 0,
                    })

            accepted_qty = received_qty - rejected_qty
            if accepted_qty < 0:
                raise ValueError(f"Item {idx}: Accepted quantity cannot be negative")

            # Discount per unit
            if discount_type == 'percentage':
                discount_per_unit = (unit_price * discount) / 100
            else:
                discount_per_unit = discount

            price_after_discount = max(unit_price - discount_per_unit, 0)

            # Tax per unit
            if tax_type == 'inclusive':
                tax_per_unit      = (price_after_discount * product_tax) / (100 + product_tax)
                subtotal_per_unit = price_after_discount
            else:
                tax_per_unit      = (price_after_discount * product_tax) / 100
                subtotal_per_unit = price_after_discount + tax_per_unit

            total_discount_amount = discount_per_unit * accepted_qty
            total_tax_amount      = tax_per_unit * accepted_qty
            item_subtotal_calc    = accepted_qty * subtotal_per_unit

            item_subtotal       += item_subtotal_calc
            total_item_discount += total_discount_amount
            total_item_tax      += total_tax_amount

            print(f"\n  📦 Item {idx}: Product {product_id}")
            print(f"     Accepted: {accepted_qty}, Cost: {unit_price}, Net Cost: {net_unit_cost}")
            print(f"     Discount rules: {len(discount_rules)} rule(s)")
            if expiration_date:
                print(f"     📅 Expiration: {expiration_date}")

            grn_items_data.append({
                'product_id':        product_id,
                'variation_id':      variation_id,
                'ordered_quantity':  ordered_qty,
                'received_quantity': received_qty,
                'rejected_quantity': rejected_qty,
                'unit_price':        unit_price,
                'batch_price':       batch_price,
                'discount':          discount,
                'discount_type':     discount_type,
                'product_tax':       product_tax,
                'tax_type':          tax_type,
                'net_unit_cost':     net_unit_cost,
                'discount_amount':   total_discount_amount,
                'tax_amount':        total_tax_amount,
                'subtotal':          item_subtotal_calc,
                'purchase_unit':     unit,
                'quality_check':     quality_check,
                'note':              item_note,
                'expiration_date':   expiration_date,
                'discount_rules':    discount_rules,   # ← carried forward
            })

        # ── Step 7: Calculate order-level totals ─────────────────────────────────
        print(f"\n📝 Step 7: Calculating order-level totals...")

        subtotal                = item_subtotal
        subtotal_after_discount = max(subtotal - order_discount_amount, 0)
        order_tax_amount        = (subtotal_after_discount * order_tax_percentage) / 100
        grand_total             = subtotal_after_discount + order_tax_amount

        if payment_status == 'Paid':
            actual_paid = grand_total
            due_amount  = 0
        elif payment_status == 'Partial':
            actual_paid = paid_amount
            due_amount  = max(grand_total - paid_amount, 0)
        else:
            actual_paid = 0
            due_amount  = grand_total

        print(f"\n  💰 Financial Summary:")
        print(f"     Item Subtotal           : LKR {item_subtotal:.2f}")
        print(f"     Order Discount          : -LKR {order_discount_amount:.2f}")
        print(f"     Subtotal after discount : LKR {subtotal_after_discount:.2f}")
        print(f"     Order Tax ({order_tax_percentage}%)      : +LKR {order_tax_amount:.2f}")
        print(f"     Grand Total             : LKR {grand_total:.2f}")
        print(f"     Paid Amount             : LKR {actual_paid:.2f}")
        print(f"     Due Amount              : LKR {due_amount:.2f}")

        # ── Step 8: Update GRN header ────────────────────────────────────────────
        print(f"\n📝 Step 8: Updating GRN header...")

        cursor.execute("""
            UPDATE grn SET
                warehouse_id   = %s,
                supplier_id    = %s,
                grn_date       = %s,
                received_by    = %s,
                invoice_number = %s,
                invoice_date   = %s,
                vehicle_number = %s,
                driver_name    = %s,
                driver_contact = %s,
                total_items    = %s,
                subtotal       = %s,
                tax            = %s,
                discount       = %s,
                grand_total    = %s,
                order_tax      = %s,
                payment_status = %s,
                payment_type   = %s,
                paid_amount    = %s,
                due_amount     = %s,
                note           = %s,
                updated_at     = NOW()
            WHERE grn_id = %s
        """, (
            warehouse_id, supplier_id,
            grn_date, received_by_id, invoice_number, invoice_date,
            vehicle_number, driver_name, driver_contact,
            total_items, subtotal, order_tax_amount, order_discount_amount, grand_total,
            order_tax_percentage, payment_status, payment_type, actual_paid, due_amount,
            note, grn_id
        ))

        print(f"  ✅ GRN header updated")
        print(f"  ✅ Warehouse : {old_warehouse_id} → {warehouse_id}")
        print(f"  ✅ Supplier  : {old_supplier_id} → {supplier_id}")

        # ── Step 9: Delete old GRN items (and their discount rules) ──────────────
        print(f"\n📝 Step 9: Deleting old GRN items and their discount rules...")

        # For PENDING GRNs the old discount rules are keyed by grn_item_id.
        # Collect old grn_item_ids before deleting.
        if pbd_table_exists and pbd_has_grn_item_id and not is_completed:
            cursor.execute(
                "SELECT grn_item_id FROM grn_items WHERE grn_id = %s", (grn_id,)
            )
            old_item_ids = [row['grn_item_id'] for row in cursor.fetchall()]

            if old_item_ids:
                fmt = ','.join(['%s'] * len(old_item_ids))
                cursor.execute(
                    f"DELETE FROM product_batch_discounts WHERE grn_item_id IN ({fmt})",
                    tuple(old_item_ids)
                )
                print(f"  ✅ Deleted {cursor.rowcount} old discount rule(s) (by grn_item_id)")

        cursor.execute("DELETE FROM grn_items WHERE grn_id = %s", (grn_id,))
        print(f"  ✅ Old GRN items deleted")

        # ── Step 10: Verify grn_items columns ────────────────────────────────────
        print(f"\n📝 Step 10: Checking grn_items table structure...")

        cursor.execute("SHOW COLUMNS FROM grn_items LIKE 'batch_price'")
        if not cursor.fetchone():
            cursor.execute("""
                ALTER TABLE grn_items
                ADD COLUMN batch_price DECIMAL(15,2) DEFAULT 0.00 AFTER unit_price
            """)
            print(f"  ✅ batch_price column added")

        cursor.execute("SHOW COLUMNS FROM grn_items LIKE 'expiration_date'")
        if not cursor.fetchone():
            cursor.execute("""
                ALTER TABLE grn_items
                ADD COLUMN expiration_date DATE NULL AFTER note
            """)
            print(f"  ✅ expiration_date column added")

        # ── Step 11: Insert updated items + discount rules (PENDING) ─────────────
        print(f"\n📝 Step 11: Inserting updated GRN items...")

        total_discount_rules_saved = 0

        for item_data in grn_items_data:
            # Insert grn_item
            cursor.execute("""
                INSERT INTO grn_items (
                    grn_id, product_id, variation_id,
                    ordered_quantity, received_quantity, rejected_quantity,
                    unit_price, batch_price,
                    discount_type, product_discount,
                    tax_type, product_tax,
                    discount, tax, net_unit_cost, subtotal,
                    purchase_unit, quality_check, note,
                    expiration_date, created_at
                ) VALUES (
                    %s, %s, %s,
                    %s, %s, %s,
                    %s, %s,
                    %s, %s,
                    %s, %s,
                    %s, %s, %s, %s,
                    %s, %s, %s,
                    %s, NOW()
                )
            """, (
                grn_id,
                item_data['product_id'],
                item_data['variation_id'],
                item_data['ordered_quantity'],
                item_data['received_quantity'],
                item_data['rejected_quantity'],
                item_data['unit_price'],
                item_data['batch_price'],
                item_data['discount_type'],
                item_data['discount'],
                item_data['tax_type'],
                item_data['product_tax'],
                item_data['discount_amount'],
                item_data['tax_amount'],
                item_data['net_unit_cost'],
                item_data['subtotal'],
                item_data['purchase_unit'],
                item_data['quality_check'],
                item_data['note'],
                item_data['expiration_date'],
            ))

            grn_item_id = cursor.lastrowid
            exp_info    = f" | Exp: {item_data['expiration_date']}" if item_data['expiration_date'] else ""
            print(f"  ✅ GRN item {grn_item_id}: Product {item_data['product_id']}{exp_info}")

            # Save discount rules for PENDING GRN (batch_id = NULL, set during approval)
            if not is_completed and pbd_table_exists and pbd_has_grn_item_id:
                rules = item_data.get('discount_rules', [])
                if rules:
                    print(f"     💾 Saving {len(rules)} discount rule(s) for grn_item {grn_item_id}...")
                    for rule in rules:
                        cursor.execute("""
                            INSERT INTO product_batch_discounts
                                (grn_item_id, batch_id, payment_method_id,
                                 discount_rate, discount_type, is_active)
                            VALUES
                                (%s, NULL, %s, %s, %s, %s)
                            ON DUPLICATE KEY UPDATE
                                discount_rate  = VALUES(discount_rate),
                                discount_type  = VALUES(discount_type),
                                is_active      = VALUES(is_active),
                                updated_at     = CURRENT_TIMESTAMP
                        """, (
                            grn_item_id,
                            rule['payment_method_id'],
                            rule['discount_rate'],
                            rule['discount_type'],
                            rule['is_active'],
                        ))
                        total_discount_rules_saved += 1
                        print(f"       ✅ PM {rule['payment_method_id']}: "
                              f"{rule['discount_rate']} {rule['discount_type']} "
                              f"({'active' if rule['is_active'] else 'inactive'})")

        # ── Step 12: Recreate batches + stock + discount rules (COMPLETED) ────────
        if is_completed:
            print(f"\n📝 Step 12: RECREATING batches, stock, and discount rules (completed GRN)...")

            cursor.execute("""
                SELECT * FROM grn_items WHERE grn_id = %s ORDER BY grn_item_id
            """, (grn_id,))
            new_items = cursor.fetchall()

            for idx, (item_data, new_item) in enumerate(zip(grn_items_data, new_items), 1):
                accepted_qty = item_data['received_quantity'] - item_data['rejected_quantity']

                if accepted_qty <= 0:
                    print(f"  ⚠️  Item {idx}: No accepted quantity, skipping")
                    continue

                batch_number = f"BATCH-{grn_code}-{idx}"

                print(f"\n  📦 Recreating Item {idx}: Product {item_data['product_id']}")
                print(f"     Accepted: {accepted_qty}")
                if item_data['expiration_date']:
                    print(f"     📅 Expiration: {item_data['expiration_date']}")

                # ── 12a: Create product batch ──────────────────────────────────
                cursor.execute("""
                    INSERT INTO product_batches (
                        batch_number, product_id, variation_id,
                        quantity, remaining_quantity,
                        cost, price,
                        grn_id, purchase_order_id,
                        expiration_date
                    ) VALUES (
                        %s, %s, %s,
                        %s, %s,
                        %s, %s,
                        %s, %s,
                        %s
                    )
                """, (
                    batch_number,
                    item_data['product_id'],
                    item_data['variation_id'],
                    accepted_qty, accepted_qty,
                    item_data['net_unit_cost'],
                    item_data['batch_price'],
                    grn_id, purchase_order_id,
                    item_data['expiration_date'],
                ))

                batch_id = cursor.lastrowid
                print(f"     ✅ Batch created: {batch_id} ({batch_number})")

                # ── 12b: Update GRN item with batch_id ────────────────────────
                cursor.execute("""
                    UPDATE grn_items SET batch_id = %s WHERE grn_item_id = %s
                """, (batch_id, new_item['grn_item_id']))

                # ── 12c: Save discount rules with real batch_id ───────────────
                # For completed GRNs we have the batch_id immediately,
                # so save rules with both grn_item_id AND batch_id populated.
                if pbd_table_exists and pbd_has_grn_item_id:
                    rules = item_data.get('discount_rules', [])
                    if rules:
                        print(f"     💾 Saving {len(rules)} discount rule(s) for batch {batch_id}...")
                        for rule in rules:
                            cursor.execute("""
                                INSERT INTO product_batch_discounts
                                    (grn_item_id, batch_id, payment_method_id,
                                     discount_rate, discount_type, is_active)
                                VALUES
                                    (%s, %s, %s, %s, %s, %s)
                                ON DUPLICATE KEY UPDATE
                                    batch_id       = VALUES(batch_id),
                                    discount_rate  = VALUES(discount_rate),
                                    discount_type  = VALUES(discount_type),
                                    is_active      = VALUES(is_active),
                                    updated_at     = CURRENT_TIMESTAMP
                            """, (
                                new_item['grn_item_id'],
                                batch_id,
                                rule['payment_method_id'],
                                rule['discount_rate'],
                                rule['discount_type'],
                                rule['is_active'],
                            ))
                            total_discount_rules_saved += 1
                            print(f"       ✅ PM {rule['payment_method_id']}: "
                                  f"{rule['discount_rate']} {rule['discount_type']} "
                                  f"→ batch_id {batch_id}")
                    else:
                        print(f"     ℹ️  No discount rules for grn_item {new_item['grn_item_id']}")

                # ── 12d: Create warehouse stock ────────────────────────────────
                cursor.execute("""
                    INSERT INTO warehouse_stock (
                        store_id, product_id, warehouse_id, batch_id,
                        variation_id, quantity
                    ) VALUES (
                        %s, %s, %s, %s,
                        %s, %s
                    )
                """, (
                    store_id,
                    item_data['product_id'],
                    warehouse_id,
                    batch_id,
                    item_data['variation_id'],
                    accepted_qty,
                ))

                print(f"     ✅ Stock created: {accepted_qty} units in warehouse {warehouse_id}")

            print(f"\n  ✅ Stock recreation complete for {len(grn_items_data)} items")

        # ── Step 13: Update PO GRN status ────────────────────────────────────────
        print(f"\n📝 Step 13: Updating PO GRN status...")

        cursor.execute("""
            SELECT
                SUM(oi.quantity) as total_ordered,
                COALESCE((
                    SELECT SUM(gi.received_quantity - gi.rejected_quantity)
                    FROM grn g2
                    INNER JOIN grn_items gi ON g2.grn_id = gi.grn_id
                    WHERE g2.purchase_order_id = %s
                    AND g2.status != 'cancelled'
                ), 0) as total_received
            FROM order_items oi
            WHERE oi.order_id = %s
        """, (purchase_order_id, purchase_order_id))

        qty_result     = cursor.fetchone()
        total_ordered  = safe_float(qty_result['total_ordered'])
        total_received = safe_float(qty_result['total_received'])

        if total_received == 0:
            new_grn_status = 'not_received'
        elif total_received < total_ordered:
            new_grn_status = 'partial'
        else:
            new_grn_status = 'fully_received'

        cursor.execute("""
            UPDATE purchase_orders SET grn_status = %s WHERE order_id = %s
        """, (new_grn_status, purchase_order_id))

        print(f"  ✅ PO GRN status: {new_grn_status}")

        # ── Step 14: Activity log ─────────────────────────────────────────────────
        print(f"\n📝 Step 14: Creating activity log...")

        action_type = 'GRN Updated (Completed)' if is_completed else 'GRN Updated'
        description = f"GRN {grn_code} updated. "

        if warehouse_id != old_warehouse_id:
            description += f"Warehouse: {old_warehouse_id} → {warehouse_id}. "

        if supplier_id != old_supplier_id:
            description += f"Supplier: {old_supplier_id} → {supplier_id}. "

        if is_completed:
            description += "Stock reversed and recreated. "

        description += (
            f"Payment: {payment_status}, "
            f"Total: LKR {grand_total:.2f}, "
            f"Discount rules saved: {total_discount_rules_saved}"
        )

        cursor.execute("""
            INSERT INTO grn_activity_log (
                grn_id, user_id, action, description, created_at
            ) VALUES (
                %s, %s, %s, %s, NOW()
            )
        """, (grn_id, current_user, action_type, description))

        print(f"  ✅ Activity logged")

        # ── Commit ───────────────────────────────────────────────────────────────
        conn.commit()
        print(f"\n✅ Transaction committed successfully!")

        print("\n" + "=" * 80)
        print("✅ GRN UPDATE COMPLETE")
        print("=" * 80)
        print(f"✅ GRN Code             : {grn_code}")
        print(f"✅ GRN ID               : {grn_id}")
        print(f"✅ Warehouse            : {old_warehouse_id} → {warehouse_id}")
        print(f"✅ Supplier             : {old_supplier_id} → {supplier_id}")
        print(f"✅ Status               : {existing_grn['status']}")
        print(f"✅ Total Items          : {total_items}")
        print(f"✅ Items w/ expiry      : {len([i for i in grn_items_data if i['expiration_date']])}")
        print(f"✅ Discount rules saved : {total_discount_rules_saved}")
        print(f"✅ Grand Total          : LKR {grand_total:.2f}")
        print(f"✅ Payment Status       : {payment_status}")

        if is_completed:
            print(f"\n⚠️  COMPLETED GRN EDITED:")
            print(f"   📦 Old stock + old discount rules REVERSED")
            print(f"   📦 New stock + new discount rules CREATED in warehouse {warehouse_id}")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': f'GRN {grn_code} updated successfully.',
            'data': {
                'grn_id':                   grn_id,
                'grn_code':                 grn_code,
                'purchase_order_id':        purchase_order_id,
                'warehouse_id':             warehouse_id,
                'supplier_id':              supplier_id,
                'status':                   existing_grn['status'],
                'total_items':              total_items,
                'subtotal':                 subtotal,
                'order_tax':                order_tax_amount,
                'order_tax_percentage':     order_tax_percentage,
                'discount':                 order_discount_amount,
                'grand_total':              grand_total,
                'payment_status':           payment_status,
                'payment_type':             payment_type,
                'paid_amount':              actual_paid,
                'due_amount':               due_amount,
                'discount_rules_saved':     total_discount_rules_saved,
                'stock_recreated':          is_completed,
            }
        }), 200

    except ValueError as err:
        conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")

@grn_bp.route('/approve_grn/<int:grn_id>', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def approve_grn(grn_id):
    """
    APPROVE GRN - COMPLETE VERSION WITH EXPIRATION_DATE + BATCH DISCOUNT RULES

    This endpoint creates product_batches and updates stock.
    After this, products are AVAILABLE FOR SALE in POS!

    Features:
    - Locks the GRN row (SELECT ... FOR UPDATE) so two concurrent approvals
      of the same GRN cannot both create batches and stock
    - Creates one product_batches row per item whose accepted quantity
      (received - rejected) is > 0, including expiration_date when that
      column exists on product_batches
    - Batch selling price = grn_items.batch_price, falling back to unit_price
      when batch_price is NULL or not present
    - Updates warehouse_stock with accepted quantities
    - Updates product_batch_discounts.batch_id (was NULL during GRN creation)
      when that table and its grn_item_id column exist
    - Updates GRN status to 'completed' and records the approver
    - Recomputes purchase_orders.status / grn_status from completed GRNs
    - Logs approval activity
    - Single transaction, rolled back on any error

    Args:
        grn_id: ID of the GRN to approve (URL parameter).

    Returns:
        200 JSON approval summary on success.
        400 JSON {'error': ...} if the GRN is missing, already completed,
            cancelled, or has no items.
        500 JSON {'error': ...} on connection, database or unexpected errors.
    """

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        conn.start_transaction()

        print("\n" + "=" * 80)
        print(f"🎯 APPROVING GRN {grn_id} - STOCK UPDATE BEGINS!")
        print("=" * 80)

        # ── Step 1: Get (and lock) GRN details ───────────────────────────────────
        print(f"\n📝 Step 1: Fetching GRN details...")
        # FOR UPDATE holds a row lock until commit/rollback. With InnoDB, a
        # concurrent approval of the same GRN blocks here until this
        # transaction finishes, then its locking read sees status='completed'
        # and is rejected below instead of creating a second set of batches
        # and stock. Only grn columns are needed, so no join is made (a join
        # would lock the joined rows as well).
        cursor.execute("""
            SELECT g.*
            FROM grn g
            WHERE g.grn_id = %s
            FOR UPDATE
        """, (grn_id,))

        grn = cursor.fetchone()

        if not grn:
            raise ValueError(f"GRN {grn_id} not found")

        if grn['status'] == 'completed':
            raise ValueError(f"GRN {grn['grn_code']} is already approved")

        if grn['status'] == 'cancelled':
            raise ValueError(f"GRN {grn['grn_code']} is cancelled and cannot be approved")

        # grand_total may be NULL; normalise once for log/description formatting.
        grand_total_value = safe_float(grn['grand_total'])

        print(f"  ✅ GRN {grn['grn_code']} found")
        print(f"  📦 Warehouse  : {grn['warehouse_id']}")
        print(f"  🚚 Supplier   : {grn['supplier_id']}")
        print(f"  💰 Grand Total: LKR {grand_total_value:.2f}")

        # ── Step 2: Get GRN items ────────────────────────────────────────────────
        print(f"\n📝 Step 2: Fetching GRN items...")
        cursor.execute("""
            SELECT
                gi.*,
                p.product_name,
                p.sku
            FROM grn_items gi
            LEFT JOIN products p ON gi.product_id = p.id
            WHERE gi.grn_id = %s
            ORDER BY gi.grn_item_id
        """, (grn_id,))

        items = cursor.fetchall()

        if not items:
            raise ValueError(f"No items found for GRN {grn_id}")

        items_with_expiry = len([i for i in items if i.get('expiration_date')])
        print(f"  ✅ Found {len(items)} items")
        print(f"  ✅ Items with expiration_date: {items_with_expiry}")

        # ── Step 3: Check product_batches table structure ────────────────────────
        print(f"\n📝 Step 3: Checking product_batches table structure...")
        cursor.execute("""
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = DATABASE()
            AND TABLE_NAME = 'product_batches'
            AND COLUMN_NAME = 'expiration_date'
        """)
        has_expiration_column = cursor.fetchone() is not None
        print(f"  {'✅' if has_expiration_column else '⚠️ '} expiration_date column: "
              f"{'EXISTS' if has_expiration_column else 'NOT FOUND'}")

        # ── Step 4: Check product_batch_discounts table ──────────────────────────
        # Verify grn_item_id column exists (added after initial table creation)
        print(f"\n📝 Step 4: Checking product_batch_discounts table...")

        pbd_table_exists = False
        pbd_has_grn_item_id = False

        cursor.execute("""
            SELECT COUNT(*) as cnt
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = DATABASE()
              AND TABLE_NAME   = 'product_batch_discounts'
        """)
        pbd_table_exists = cursor.fetchone()['cnt'] > 0

        if pbd_table_exists:
            cursor.execute("""
                SELECT COUNT(*) as cnt
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = DATABASE()
                  AND TABLE_NAME   = 'product_batch_discounts'
                  AND COLUMN_NAME  = 'grn_item_id'
            """)
            pbd_has_grn_item_id = cursor.fetchone()['cnt'] > 0

        print(f"  {'✅' if pbd_table_exists else '⚠️ '} product_batch_discounts table: "
              f"{'EXISTS' if pbd_table_exists else 'NOT FOUND'}")
        print(f"  {'✅' if pbd_has_grn_item_id else '⚠️ '} grn_item_id column: "
              f"{'EXISTS' if pbd_has_grn_item_id else 'NOT FOUND — discount rules will be skipped'}")

        # ── Step 5: Create batches, update stock, link discount rules ───────────
        print(f"\n📝 Step 5: Creating batches, updating stock, linking discount rules...")

        current_user          = get_jwt_identity()
        batches_created       = 0
        items_processed       = 0
        discount_rules_linked = 0

        for idx, item in enumerate(items, 1):
            # Treat NULL quantities as 0 rather than crashing on None arithmetic.
            received_qty = item['received_quantity'] or 0
            rejected_qty = item['rejected_quantity'] or 0
            accepted_qty = received_qty - rejected_qty

            if accepted_qty <= 0:
                print(f"  ⚠️  Item {idx}: No accepted quantity, skipping")
                continue

            items_processed += 1

            batch_number = f"BATCH-{grn['grn_code']}-{idx}"

            # Fall back to unit_price when batch_price is NULL *or* absent.
            # dict.get's default only covers a missing key; when grn_items has
            # a batch_price column, gi.* always yields the key (possibly None),
            # which previously produced a batch price of 0.0.
            raw_batch_price = item.get('batch_price')
            if raw_batch_price is None:
                raw_batch_price = item['unit_price']
            batch_price = safe_float(raw_batch_price)

            expiration_date = item.get('expiration_date')

            print(f"\n  📦 Item {idx}:")
            print(f"     Product      : {item['product_id']} ({item.get('product_name', 'N/A')})")
            print(f"     Ordered      : {item['ordered_quantity']}, "
                  f"Received: {item['received_quantity']}, "
                  f"Rejected: {item['rejected_quantity']}")
            print(f"     Accepted Qty : {accepted_qty}")
            print(f"     Cost         : {item['unit_price']}")
            print(f"     Batch Price  : {batch_price}")
            print(f"     📅 Expiration: {expiration_date if expiration_date else 'No expiry date'}")

            # ── 5a: Create product batch ─────────────────────────────────────────
            if has_expiration_column:
                cursor.execute("""
                    INSERT INTO product_batches (
                        batch_number, product_id, variation_id,
                        quantity, remaining_quantity,
                        cost, price,
                        grn_id, purchase_order_id,
                        expiration_date
                    ) VALUES (
                        %s, %s, %s,
                        %s, %s,
                        %s, %s,
                        %s, %s,
                        %s
                    )
                """, (
                    batch_number, item['product_id'], item['variation_id'],
                    accepted_qty, accepted_qty,
                    item['net_unit_cost'], batch_price,
                    grn_id, grn['purchase_order_id'],
                    expiration_date
                ))
                print(f"     ✅ Batch created with expiration_date: {expiration_date}")
            else:
                cursor.execute("""
                    INSERT INTO product_batches (
                        batch_number, product_id, variation_id,
                        quantity, remaining_quantity,
                        cost, price,
                        grn_id, purchase_order_id
                    ) VALUES (
                        %s, %s, %s,
                        %s, %s,
                        %s, %s,
                        %s, %s
                    )
                """, (
                    batch_number, item['product_id'], item['variation_id'],
                    accepted_qty, accepted_qty,
                    item['net_unit_cost'], batch_price,
                    grn_id, grn['purchase_order_id']
                ))
                print(f"     ⚠️  Batch created without expiration_date (column not found)")

            batch_id = cursor.lastrowid
            batches_created += 1
            print(f"     ✅ Batch {batch_id} created: {batch_number}")

            # ── 5b: Update GRN item with batch_id ────────────────────────────────
            cursor.execute("""
                UPDATE grn_items
                SET batch_id = %s
                WHERE grn_item_id = %s
            """, (batch_id, item['grn_item_id']))

            # ── 5c: Link discount rules — update batch_id in product_batch_discounts
            # During create_grn, rules were saved with batch_id = NULL.
            # Now we have the real batch_id — update it.
            if pbd_table_exists and pbd_has_grn_item_id:
                cursor.execute("""
                    SELECT COUNT(*) as cnt
                    FROM product_batch_discounts
                    WHERE grn_item_id = %s
                """, (item['grn_item_id'],))

                rules_count = cursor.fetchone()['cnt']

                if rules_count > 0:
                    cursor.execute("""
                        UPDATE product_batch_discounts
                        SET batch_id   = %s,
                            updated_at = NOW()
                        WHERE grn_item_id = %s
                    """, (batch_id, item['grn_item_id']))

                    discount_rules_linked += rules_count
                    print(f"     ✅ Discount rules linked: {rules_count} rule(s) → batch_id {batch_id}")
                else:
                    print(f"     ℹ️  No discount rules found for grn_item {item['grn_item_id']}")
            else:
                print(f"     ℹ️  Discount rules skipped (table/column not available)")

            # ── 5d: Update warehouse stock ────────────────────────────────────────
            cursor.execute("""
                SELECT id, quantity
                FROM warehouse_stock
                WHERE store_id   = %s
                  AND product_id = %s
                  AND warehouse_id = %s
                  AND batch_id   = %s
            """, (grn['store_id'], item['product_id'], grn['warehouse_id'], batch_id))

            existing_stock = cursor.fetchone()

            if existing_stock:
                new_qty = existing_stock['quantity'] + accepted_qty
                cursor.execute("""
                    UPDATE warehouse_stock
                    SET quantity = %s, updated_at = NOW()
                    WHERE id = %s
                """, (new_qty, existing_stock['id']))
                print(f"     ✅ Stock updated: {existing_stock['quantity']} → {new_qty}")
            else:
                cursor.execute("""
                    INSERT INTO warehouse_stock (
                        store_id, product_id, warehouse_id, batch_id,
                        variation_id, quantity
                    ) VALUES (
                        %s, %s, %s, %s,
                        %s, %s
                    )
                """, (
                    grn['store_id'], item['product_id'],
                    grn['warehouse_id'], batch_id,
                    item['variation_id'], accepted_qty
                ))
                print(f"     ✅ Stock created: {accepted_qty} units in warehouse {grn['warehouse_id']}")

        print(f"\n  ✅ Batches created        : {batches_created}/{len(items)}")
        print(f"  ✅ Items processed         : {items_processed}")
        print(f"  ✅ Discount rules linked   : {discount_rules_linked}")

        # ── Step 6: Update GRN status ────────────────────────────────────────────
        print(f"\n📝 Step 6: Updating GRN status...")
        cursor.execute("""
            UPDATE grn
            SET status      = 'completed',
                approved_by = %s,
                approved_at = NOW()
            WHERE grn_id = %s
        """, (current_user, grn_id))
        print(f"  ✅ GRN status: pending → completed")

        # ── Step 7: Update Purchase Order ────────────────────────────────────────
        print(f"\n📝 Step 7: Updating Purchase Order...")

        # Counts accepted quantities across all *completed* GRNs of the PO,
        # including this one (its status was set in Step 6 of this transaction).
        cursor.execute("""
            SELECT
                SUM(oi.quantity) as total_ordered,
                COALESCE((
                    SELECT SUM(gi.received_quantity - gi.rejected_quantity)
                    FROM grn g2
                    INNER JOIN grn_items gi ON g2.grn_id = gi.grn_id
                    WHERE g2.purchase_order_id = %s
                    AND g2.status = 'completed'
                ), 0) as total_received
            FROM order_items oi
            WHERE oi.order_id = %s
        """, (grn['purchase_order_id'], grn['purchase_order_id']))

        po_qty         = cursor.fetchone()
        total_ordered  = safe_float(po_qty['total_ordered'])
        total_received = safe_float(po_qty['total_received'])

        print(f"  📊 Total Ordered  : {total_ordered}")
        print(f"  📊 Total Received : {total_received}")

        # Same three-state grn_status rule as the GRN update path:
        # nothing accepted → 'not_received', some → 'partial', all → 'fully_received'.
        if total_received == 0:
            new_po_status  = 'Ordered'
            new_grn_status = 'not_received'
        elif total_received >= total_ordered:
            new_po_status  = 'Received'
            new_grn_status = 'fully_received'
        else:
            new_po_status  = 'Ordered'
            new_grn_status = 'partial'

        cursor.execute("""
            UPDATE purchase_orders
            SET status     = %s,
                grn_status = %s
            WHERE order_id = %s
        """, (new_po_status, new_grn_status, grn['purchase_order_id']))

        print(f"  ✅ PO Status    : {new_po_status}")
        print(f"  ✅ PO GRN Status: {new_grn_status}")

        # ── Step 8: Log activity ─────────────────────────────────────────────────
        print(f"\n📝 Step 8: Logging approval activity...")

        description = (
            f"GRN {grn['grn_code']} approved. "
            f"Batches: {batches_created}, "
            f"Items: {items_processed}, "
            f"Items with expiry: {items_with_expiry}, "
            f"Discount rules linked: {discount_rules_linked}, "
            f"Stock: {new_grn_status}, "
            f"Total: LKR {grand_total_value:.2f}"
        )

        cursor.execute("""
            INSERT INTO grn_activity_log (
                grn_id, user_id, action, description, created_at
            ) VALUES (
                %s, %s, 'GRN Approved', %s, NOW()
            )
        """, (grn_id, current_user, description))

        print(f"  ✅ Activity logged")

        # ── Commit ───────────────────────────────────────────────────────────────
        conn.commit()

        print("\n" + "=" * 80)
        print("🎉 GRN APPROVAL COMPLETE - STOCK NOW AVAILABLE!")
        print("=" * 80)
        print(f"✅ GRN {grn['grn_code']} approved successfully")
        print(f"✅ Batches created         : {batches_created}")
        print(f"✅ Items processed         : {items_processed}")
        print(f"✅ Items with expiry       : {items_with_expiry}")
        print(f"✅ Discount rules linked   : {discount_rules_linked}")
        print(f"✅ Warehouse updated       : {grn['warehouse_id']}")
        print(f"✅ PO Status               : {new_po_status} / {new_grn_status}")
        print(f"✅ Products available for sale in POS!")
        print("=" * 80)

        return jsonify({
            'success': True,
            'message': f'GRN {grn["grn_code"]} approved successfully. Stock updated!',
            'data': {
                'grn_id':                grn_id,
                'grn_code':              grn['grn_code'],
                'status':                'completed',
                'batches_created':       batches_created,
                'items_processed':       items_processed,
                'items_with_expiry':     items_with_expiry,
                'discount_rules_linked': discount_rules_linked,
                'stock_updated':         True,
                'warehouse_id':          grn['warehouse_id'],
                'po_status':             new_po_status,
                'po_grn_status':         new_grn_status,
                'grand_total':           grn['grand_total'],
            }
        }), 200

    except ValueError as err:
        conn.rollback()
        print(f"\n❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except mysql.connector.Error as err:
        conn.rollback()
        print(f"\n❌ Database Error: {err}")
        traceback.print_exc()
        return jsonify({'error': f'Database error: {str(err)}'}), 500

    except Exception as e:
        conn.rollback()
        print(f"\n❌ Unexpected Error: {e}")
        traceback.print_exc()
        return jsonify({'error': f'Server error: {str(e)}'}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")
# ============================================================================
# END OF approve_grn FUNCTION
# ============================================================================

@grn_bp.route('/get_grns', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_grns():
    """Get all GRNs with complete details including payment information.

    Every GRN (ordered by grn_id descending) is joined with its supplier,
    warehouse, store and receiving-user names. Each GRN dict gets an
    ``items`` list of its grn_items rows, enriched with product, variation
    and batch-number details (empty list when it has no items).

    Items for all GRNs are loaded with ONE query and grouped by grn_id in
    Python. This avoids issuing a separate query per GRN (N+1 round trips).

    Returns:
        200 JSON {'success': True, 'data': [...], 'count': N}.
        500 JSON {'error': ...} on connection or query failure.
    """

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute("""
            SELECT 
                g.grn_id, g.grn_code, g.purchase_order_id, g.grn_date,
                g.status, g.total_items, g.subtotal, g.tax, g.discount, g.grand_total,
                g.order_tax, g.payment_status, g.payment_type, g.paid_amount, g.due_amount,
                g.note, g.invoice_number, g.invoice_date,
                g.vehicle_number, g.driver_name, g.driver_contact,
                s.supplier_name, w.warehouse_name, st.store_name,
                u.name as received_by_name,
                g.created_at, g.approved_at
            FROM grn g
            LEFT JOIN suppliers s ON g.supplier_id = s.id
            LEFT JOIN warehouses w ON g.warehouse_id = w.id
            LEFT JOIN stores st ON g.store_id = st.id
            LEFT JOIN users u ON g.received_by = u.id
            ORDER BY g.grn_id DESC
        """)

        grns = cursor.fetchall()

        # grn_id -> list of item rows for that GRN.
        items_by_grn = {}

        if grns:
            # Single query for all items; every GRN is returned anyway, so this
            # loads the same rows the per-GRN queries would have loaded.
            cursor.execute("""
                SELECT 
                    gi.*, p.product_name, p.sku,
                    pv.variation_name, pv.variation_type,
                    pb.batch_number
                FROM grn_items gi
                LEFT JOIN products p ON gi.product_id = p.id
                LEFT JOIN product_variations pv ON gi.variation_id = pv.id
                LEFT JOIN product_batches pb ON gi.batch_id = pb.batch_id
                ORDER BY gi.grn_id, gi.grn_item_id
            """)

            for row in cursor.fetchall():
                items_by_grn.setdefault(row['grn_id'], []).append(row)

        for grn in grns:
            grn['items'] = items_by_grn.get(grn['grn_id'], [])

        return jsonify({
            'success': True,
            'data': grns,
            'count': len(grns)
        }), 200

    except Exception as e:
        print(f"❌ Error fetching GRNs: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()


@grn_bp.route('/get_grn/<int:grn_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_grn(grn_id):
    """
    GET GRN - COMPLETE VERSION

    Returns complete GRN details including:
    - GRN header with all fields
    - Warehouse and supplier information
    - All pricing details (unit_price, batch_price)
    - Payment information
    - Expiration dates for items
    - Discount rules per item (from product_batch_discounts)
    - Activity log with all changes
    - Properly formatted dates for HTML inputs
    - Product and variation details
    - Batch information if approved

    Notes:
    - Discount rules are matched by product_batch_discounts.grn_item_id and
      loaded with a single query for all items. If that table or its
      grn_item_id column does not exist (approve_grn also treats both as
      optional), every item gets ``discount_rules: []`` instead of the
      request failing.
    - Header dates grn_date / invoice_date are formatted 'YYYY-MM-DD';
      created_at / updated_at / approved_at are formatted
      'YYYY-MM-DD HH:MM:SS'. Item expiration_date is formatted 'YYYY-MM-DD'.
      Activity-log rows (al.* plus user_name) are returned as fetched, so
      their created_at is serialized by jsonify, not pre-formatted here.

    Response Format:
    {
      "success": true,
      "data": {
        "grn_id": 123,
        "grn_code": "GRN-2026-00001",
        "warehouse_name": "Main Warehouse",
        "supplier_name": "Supplier ABC",
        "payment_status": "Paid",
        "items": [
          {
            "grn_item_id": 1,
            "product_id": 10,
            "product_name": "Product Name",
            "unit_price": 50.00,
            "batch_price": 55.00,
            "expiration_date": "2026-12-31",
            "discount_rules": [
              {
                "payment_method_id": 1,
                "discount_type": "percent",
                "discount_rate": 5.00,
                "is_active": 1
              }
            ],
            ...
          }
        ],
        "activity_log": [
          {
            "action": "GRN Approved",
            "description": "...",
            "user_name": "Admin User",
            "created_at": <serialized by jsonify>
          }
        ],
        "summary": { ... }
      }
    }

    Returns:
        200 JSON as above; 404 if the GRN does not exist; 500 on errors.
    """

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        print(f"\n📦 Fetching GRN {grn_id}...")

        # ── Step 1: Get GRN header ───────────────────────────────────────────────
        print(f"\n📝 Step 1: Fetching GRN header...")
        cursor.execute("""
            SELECT
                g.*,
                s.supplier_name, s.supplier_email,
                s.supplier_contact, s.supplier_address,
                w.warehouse_name,
                st.store_name, st.address as store_address,
                u.name  as received_by_name,
                u2.name as created_by_name,
                u3.name as approved_by_name
            FROM grn g
            LEFT JOIN suppliers  s  ON g.supplier_id  = s.id
            LEFT JOIN warehouses w  ON g.warehouse_id = w.id
            LEFT JOIN stores     st ON g.store_id     = st.id
            LEFT JOIN users      u  ON g.received_by  = u.id
            LEFT JOIN users      u2 ON g.created_by   = u2.id
            LEFT JOIN users      u3 ON g.approved_by  = u3.id
            WHERE g.grn_id = %s
        """, (grn_id,))

        grn = cursor.fetchone()

        if not grn:
            print(f"❌ GRN {grn_id} not found")
            return jsonify({'error': 'GRN not found'}), 404

        # grand_total may be NULL; normalise once so the log lines can't crash.
        grand_total_value = safe_float(grn['grand_total'])

        print(f"  ✅ GRN {grn['grn_code']} found")
        print(f"     Status     : {grn['status']}")
        print(f"     Grand Total: LKR {grand_total_value:.2f}")

        # ── Step 2: Get GRN items ────────────────────────────────────────────────
        print(f"\n📝 Step 2: Fetching GRN items...")
        cursor.execute("""
            SELECT
                gi.*,
                p.product_name, p.sku,
                pv.variation_name, pv.variation_type, pv.variation_sku,
                pb.batch_number, pb.batch_id
            FROM grn_items gi
            LEFT JOIN products           p  ON gi.product_id   = p.id
            LEFT JOIN product_variations pv ON gi.variation_id = pv.id
            LEFT JOIN product_batches    pb ON gi.batch_id     = pb.batch_id
            WHERE gi.grn_id = %s
            ORDER BY gi.grn_item_id
        """, (grn_id,))

        items = cursor.fetchall()
        print(f"  ✅ Found {len(items)} items")

        # ── Step 2b: Fetch discount rules per item ───────────────────────────────
        print(f"\n📝 Step 2b: Fetching discount rules per item...")

        # The grn_item_id column exists only if the table exists, so one
        # INFORMATION_SCHEMA lookup covers both conditions.
        cursor.execute("""
            SELECT COUNT(*) as cnt
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = DATABASE()
              AND TABLE_NAME   = 'product_batch_discounts'
              AND COLUMN_NAME  = 'grn_item_id'
        """)
        discount_rules_available = cursor.fetchone()['cnt'] > 0

        # grn_item_id -> list of rule dicts, in payment_method_id order.
        rules_by_item = {}
        item_ids = [item['grn_item_id'] for item in items]

        if not discount_rules_available:
            print(f"  ⚠️  product_batch_discounts.grn_item_id not found — discount rules skipped")
        elif item_ids:
            # Only the placeholder list is interpolated; the ids themselves
            # are passed as bound parameters.
            placeholders = ', '.join(['%s'] * len(item_ids))
            cursor.execute(f"""
                SELECT
                    grn_item_id,
                    payment_method_id,
                    discount_type,
                    discount_rate,
                    is_active
                FROM product_batch_discounts
                WHERE grn_item_id IN ({placeholders})
                ORDER BY grn_item_id, payment_method_id
            """, tuple(item_ids))

            for rule in cursor.fetchall():
                rules_by_item.setdefault(rule['grn_item_id'], []).append({
                    'payment_method_id': rule['payment_method_id'],
                    'discount_type':     rule['discount_type'],
                    # safe_float: a NULL rate would make float() raise.
                    'discount_rate':     safe_float(rule['discount_rate']),
                    'is_active':         rule['is_active'],
                })

        total_rules_loaded = 0
        for item in items:
            item['discount_rules'] = rules_by_item.get(item['grn_item_id'], [])
            total_rules_loaded += len(item['discount_rules'])

        print(f"  ✅ Discount rules loaded: {total_rules_loaded} total across {len(items)} items")

        # ── Step 3: Format item dates ────────────────────────────────────────────
        print(f"\n📝 Step 3: Formatting item dates...")

        items_with_expiration = 0

        for item in items:
            expiry = item.get('expiration_date')
            if expiry:
                if isinstance(expiry, (datetime, date)):
                    item['expiration_date'] = expiry.strftime('%Y-%m-%d')
                items_with_expiration += 1

        print(f"  ✅ Items with expiration: {items_with_expiration}")

        grn['items'] = items

        # ── Step 4: Get activity log ─────────────────────────────────────────────
        print(f"\n📝 Step 4: Fetching activity log...")
        cursor.execute("""
            SELECT
                al.*,
                u.name as user_name
            FROM grn_activity_log al
            LEFT JOIN users u ON al.user_id = u.id
            WHERE al.grn_id = %s
            ORDER BY al.created_at DESC
        """, (grn_id,))

        activity_log = cursor.fetchall()
        print(f"  ✅ Found {len(activity_log)} activity entries")

        grn['activity_log'] = activity_log

        # ── Step 5: Format header dates for HTML inputs ──────────────────────────
        print(f"\n📝 Step 5: Formatting header dates...")

        # Date-only fields: accept date or datetime, emit YYYY-MM-DD.
        for field in ('grn_date', 'invoice_date'):
            value = grn.get(field)
            if value and isinstance(value, (datetime, date)):
                grn[field] = value.strftime('%Y-%m-%d')
                print(f"  ✅ {field} formatted: {grn[field]}")

        # Timestamp fields: only datetime values are reformatted.
        for field in ('created_at', 'updated_at', 'approved_at'):
            value = grn.get(field)
            if value and isinstance(value, datetime):
                grn[field] = value.strftime('%Y-%m-%d %H:%M:%S')

        # ── Step 6: Calculate summary stats ─────────────────────────────────────
        print(f"\n📝 Step 6: Calculating summary statistics...")

        total_ordered  = sum(safe_float(i.get('ordered_quantity',  0)) for i in items)
        total_received = sum(safe_float(i.get('received_quantity', 0)) for i in items)
        total_rejected = sum(safe_float(i.get('rejected_quantity', 0)) for i in items)
        total_accepted = total_received - total_rejected

        grn['summary'] = {
            'total_items':           len(items),
            'items_with_expiration': items_with_expiration,
            'total_ordered':         total_ordered,
            'total_received':        total_received,
            'total_rejected':        total_rejected,
            'total_accepted':        total_accepted,
            'total_activities':      len(activity_log),
            'total_discount_rules':  total_rules_loaded,
        }

        print(f"  ✅ Summary calculated:")
        print(f"     Total Items        : {len(items)}")
        print(f"     Total Ordered      : {total_ordered}")
        print(f"     Total Received     : {total_received}")
        print(f"     Total Rejected     : {total_rejected}")
        print(f"     Total Accepted     : {total_accepted}")
        print(f"     Total Discount Rules: {total_rules_loaded}")

        print(f"\n✅ GRN {grn['grn_code']} fetched successfully")
        print(f"   Status               : {grn['status']}")
        print(f"   Items                : {len(items)}")
        print(f"   Items with expiration: {items_with_expiration}")
        print(f"   Discount rules loaded: {total_rules_loaded}")
        print(f"   Total                : LKR {grand_total_value:.2f}")

        return jsonify({
            'success': True,
            'data':    grn,
        }), 200

    except Exception as e:
        print(f"❌ Error fetching GRN {grn_id}: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        print("\n🔒 Database connection closed\n")
@grn_bp.route('/cancel_grn/<int:grn_id>', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def cancel_grn(grn_id):
    """Cancel a GRN that has not been approved yet.

    Optional JSON body: ``{"reason": "..."}``. The reason is written to
    ``grn_activity_log``; a missing or non-JSON body is accepted.

    All changes run in one transaction:
      1. load (and row-lock) the GRN; reject it if missing, 'completed'
         or already 'cancelled';
      2. set ``grn.status`` to 'cancelled' and log a 'GRN Cancelled' entry;
      3. recompute ``purchase_orders.grn_status`` for the parent PO from the
         accepted quantity (received - rejected) of its GRNs that are not
         cancelled.

    Returns:
        200 on success; 400 for the validation failures in step 1;
        500 on any other error (the transaction is rolled back).
    """
    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # silent=True: the reason is optional, so a POST without a JSON body
        # must not be rejected (plain get_json() raises 400/415 for that).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        # `or ''` also covers an explicit JSON null, which has no .strip().
        reason = str(data.get('reason') or '').strip()

        conn.start_transaction()

        # FOR UPDATE row-locks the GRN for the rest of the transaction so the
        # status check below and the UPDATE cannot interleave with another
        # writer that locks or updates the same row.
        cursor.execute("SELECT * FROM grn WHERE grn_id = %s FOR UPDATE", (grn_id,))
        grn = cursor.fetchone()

        if not grn:
            raise ValueError(f"GRN {grn_id} not found")

        if grn['status'] == 'completed':
            raise ValueError("Cannot cancel an approved GRN")

        if grn['status'] == 'cancelled':
            raise ValueError("GRN is already cancelled")

        # Update GRN status
        cursor.execute("""
            UPDATE grn
            SET status = 'cancelled'
            WHERE grn_id = %s
        """, (grn_id,))

        # Log activity
        current_user = get_jwt_identity()
        cursor.execute("""
            INSERT INTO grn_activity_log (
                grn_id, user_id, action, description, created_at
            ) VALUES (
                %s, %s, 'GRN Cancelled', %s, NOW()
            )
        """, (grn_id, current_user, f"Reason: {reason}" if reason else "GRN cancelled"))

        # Recompute the PO's receiving status. The GRN cancelled above is
        # already excluded by the `status != 'cancelled'` filter because the
        # UPDATE is visible inside this transaction.
        cursor.execute("""
            SELECT 
                SUM(oi.quantity) as total_ordered,
                COALESCE((
                    SELECT SUM(gi.received_quantity - gi.rejected_quantity)
                    FROM grn g2
                    INNER JOIN grn_items gi ON g2.grn_id = gi.grn_id
                    WHERE g2.purchase_order_id = %s
                    AND g2.status != 'cancelled'
                ), 0) as total_received
            FROM order_items oi
            WHERE oi.order_id = %s
        """, (grn['purchase_order_id'], grn['purchase_order_id']))

        po_qty = cursor.fetchone()
        total_ordered = safe_float(po_qty['total_ordered'])
        total_received = safe_float(po_qty['total_received'])

        if total_received == 0:
            new_grn_status = 'not_received'
        elif total_received < total_ordered:
            new_grn_status = 'partial'
        else:
            new_grn_status = 'fully_received'

        cursor.execute("""
            UPDATE purchase_orders
            SET grn_status = %s
            WHERE order_id = %s
        """, (new_grn_status, grn['purchase_order_id']))

        conn.commit()

        return jsonify({
            'success': True,
            'message': f'GRN {grn["grn_code"]} cancelled successfully'
        }), 200

    except ValueError as err:
        conn.rollback()
        return jsonify({'error': str(err)}), 400

    except Exception as e:
        conn.rollback()
        print(f"❌ Error cancelling GRN: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()


@grn_bp.route('/delete_grn/<int:grn_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin', 'manager')
def delete_grn(grn_id):
    """Permanently delete a GRN that has not been approved.

    Required JSON body: ``{"reason": "..."}``.

    In one transaction this row-locks the GRN, rejects it if it is missing
    or 'completed', deletes its ``grn_items``, ``grn_activity_log`` and
    ``grn_attachments`` rows and the ``grn`` row itself, then recomputes
    ``purchase_orders.grn_status`` for the parent PO from the remaining
    non-cancelled GRNs. Attachment files are removed from disk only after
    the commit succeeds (best effort).

    Returns:
        200 on success; 400 if the reason is missing, or the GRN is missing
        or completed; 500 on any other error (the transaction is rolled back).
    """
    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        # silent=True so a missing/non-JSON body yields the "reason required"
        # 400 below instead of a 400/415 raised by get_json() (caught as 500).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        # `or ''` also covers an explicit JSON null, which has no .strip().
        reason = str(data.get('reason') or '').strip()

        if not reason:
            return jsonify({'error': 'Reason for deletion is required'}), 400

        conn.start_transaction()

        # FOR UPDATE row-locks the GRN so the status check and the deletes
        # cannot interleave with another writer that locks the same row.
        cursor.execute("SELECT * FROM grn WHERE grn_id = %s FOR UPDATE", (grn_id,))
        grn = cursor.fetchone()

        if not grn:
            raise ValueError(f"GRN {grn_id} not found")

        if grn['status'] == 'completed':
            raise ValueError("Cannot delete an approved GRN. Stock has already been updated.")

        current_user = get_jwt_identity()
        print(f"\n🗑️ Deleting GRN {grn['grn_code']}...")
        print(f"   Reason: {reason}")
        print(f"   Requested by: {current_user}")
        # NOTE(review): every grn_activity_log row for this GRN is deleted
        # below, so an activity entry written here would be removed in the
        # same transaction. The reason is therefore only printed. TODO: persist
        # it to a table that does not reference grn if a durable audit trail
        # is required.

        # Delete GRN items first (foreign key constraint)
        cursor.execute("DELETE FROM grn_items WHERE grn_id = %s", (grn_id,))
        print(f"  ✅ GRN items deleted")

        # Delete GRN activity log
        cursor.execute("DELETE FROM grn_activity_log WHERE grn_id = %s", (grn_id,))
        print(f"  ✅ Activity log deleted")

        # Attachment rows also reference the GRN. Remember their file paths so
        # the files can be removed once the deletion has been committed.
        cursor.execute("SELECT file_path FROM grn_attachments WHERE grn_id = %s", (grn_id,))
        attachment_paths = [row['file_path'] for row in cursor.fetchall() if row.get('file_path')]
        cursor.execute("DELETE FROM grn_attachments WHERE grn_id = %s", (grn_id,))
        print(f"  ✅ Attachments deleted ({len(attachment_paths)})")

        # Delete GRN
        cursor.execute("DELETE FROM grn WHERE grn_id = %s", (grn_id,))
        print(f"  ✅ GRN deleted")

        # Recompute the PO's receiving status from the GRNs that remain.
        cursor.execute("""
            SELECT 
                SUM(oi.quantity) as total_ordered,
                COALESCE((
                    SELECT SUM(gi.received_quantity - gi.rejected_quantity)
                    FROM grn g2
                    INNER JOIN grn_items gi ON g2.grn_id = gi.grn_id
                    WHERE g2.purchase_order_id = %s
                    AND g2.status != 'cancelled'
                ), 0) as total_received
            FROM order_items oi
            WHERE oi.order_id = %s
        """, (grn['purchase_order_id'], grn['purchase_order_id']))

        po_qty = cursor.fetchone()
        total_ordered = safe_float(po_qty['total_ordered'])
        total_received = safe_float(po_qty['total_received'])

        if total_received == 0:
            new_grn_status = 'not_received'
        elif total_received < total_ordered:
            new_grn_status = 'partial'
        else:
            new_grn_status = 'fully_received'

        cursor.execute("""
            UPDATE purchase_orders
            SET grn_status = %s
            WHERE order_id = %s
        """, (new_grn_status, grn['purchase_order_id']))

        print(f"  ✅ PO GRN status updated: {new_grn_status}")

        conn.commit()

        # Files are removed only after the commit, so a rollback never leaves
        # DB rows pointing at deleted files. A failure here just orphans a file.
        for path in attachment_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as file_err:
                print(f"  ⚠️ Could not remove attachment file {path}: {file_err}")

        print(f"✅ GRN {grn['grn_code']} deleted successfully")

        return jsonify({
            'success': True,
            'message': f'GRN {grn["grn_code"]} deleted successfully'
        }), 200

    except ValueError as err:
        conn.rollback()
        print(f"❌ Validation Error: {err}")
        return jsonify({'error': str(err)}), 400

    except Exception as e:
        conn.rollback()
        print(f"❌ Error deleting GRN: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()
        
        
@grn_bp.route('/get_grn_attachments/<int:grn_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_grn_attachments(grn_id):
    """List the attachments stored for one GRN, most recent upload first.

    Each row holds every ``grn_attachments`` column plus ``uploaded_by_name``
    (``users.name`` via LEFT JOIN, so None when the uploader row is absent).
    ``uploaded_at`` datetime values are rendered as ``YYYY-MM-DD HH:MM:SS``.

    Returns:
        200 with ``{'success', 'data', 'count'}``; 500 on DB errors.
    """
    connection = get_db_connection()
    if not connection:
        return jsonify({'error': 'Database connection failed'}), 500

    cur = connection.cursor(dictionary=True)

    try:
        cur.execute("""
            SELECT 
                ga.*,
                u.name as uploaded_by_name
            FROM grn_attachments ga
            LEFT JOIN users u ON ga.uploaded_by = u.id
            WHERE ga.grn_id = %s
            ORDER BY ga.uploaded_at DESC
        """, (grn_id,))

        rows = cur.fetchall()

        for row in rows:
            stamp = row.get('uploaded_at')
            # Only real datetime objects are reformatted; other values pass through.
            if stamp and isinstance(stamp, datetime):
                row['uploaded_at'] = stamp.strftime('%Y-%m-%d %H:%M:%S')

        payload = {'success': True, 'data': rows, 'count': len(rows)}
        return jsonify(payload), 200

    except Exception as exc:
        print(f"❌ Error fetching attachments: {str(exc)}")
        traceback.print_exc()
        return jsonify({'error': str(exc)}), 500

    finally:
        cur.close()
        connection.close()

@grn_bp.route('/upload_grn_attachments', methods=['POST'])
@jwt_required()
@role_required('admin', 'manager')
def upload_grn_attachments():
    """Upload one or more files and attach them to an existing GRN.

    Multipart form fields:
        files:           one or more file parts (required).
        grn_id:          id of the target GRN (required, must be an integer).
        attachment_type: free-text category stored on each row
                         (default 'other').

    Each accepted file is saved under ``uploads/grn_attachments/<grn_id>/``
    with a unique name, and one ``grn_attachments`` row is inserted per file.
    Files whose extension is not allowed (or whose name sanitises to
    nothing) are skipped and listed in ``skipped_files``.

    Returns:
        200 with the stored files; 400 for missing/invalid input; 404 if the
        GRN does not exist; 500 on DB or storage errors. On a 500, every file
        this request already wrote to disk is removed again.
    """
    import uuid
    from werkzeug.utils import secure_filename

    UPLOAD_FOLDER = 'uploads/grn_attachments'
    ALLOWED_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png', 'doc', 'docx'}

    def allowed_file(filename):
        return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

    if 'files' not in request.files:
        return jsonify({'error': 'No files provided'}), 400

    files = request.files.getlist('files')
    attachment_type = request.form.get('attachment_type', 'other')

    raw_grn_id = request.form.get('grn_id')
    if not raw_grn_id:
        return jsonify({'error': 'GRN ID is required'}), 400

    # grn_id becomes part of a filesystem path, so accept only a plain integer.
    # Otherwise MySQL's string-to-number coercion could let a value like
    # "5/../../x" match GRN 5 and then direct writes outside UPLOAD_FOLDER.
    try:
        grn_id = int(raw_grn_id)
    except (TypeError, ValueError):
        return jsonify({'error': 'GRN ID must be an integer'}), 400

    current_user = get_jwt_identity()

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    uploaded_files = []
    skipped_files = []
    saved_paths = []  # files written by this request, removed again on failure

    try:
        # Verify GRN exists
        cursor.execute("SELECT grn_id FROM grn WHERE grn_id = %s", (grn_id,))
        if not cursor.fetchone():
            return jsonify({'error': 'GRN not found'}), 404

        grn_upload_dir = os.path.join(UPLOAD_FOLDER, str(grn_id))
        os.makedirs(grn_upload_dir, exist_ok=True)

        for file in files:
            # Empty parts (e.g. a file input left blank) carry no filename.
            if not file or not file.filename:
                continue

            original_filename = secure_filename(file.filename)
            if not original_filename or not allowed_file(file.filename):
                skipped_files.append(file.filename)
                continue

            # The timestamp keeps names sortable. The random suffix stops two
            # uploads with the same name in the same second from overwriting
            # each other on disk.
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"{timestamp}_{uuid.uuid4().hex[:8]}_{original_filename}"

            file_path = os.path.join(grn_upload_dir, filename)
            file.save(file_path)
            saved_paths.append(file_path)

            file_size = os.path.getsize(file_path)
            file_type = file.content_type

            cursor.execute("""
                INSERT INTO grn_attachments (
                    grn_id, file_name, file_path, file_type, file_size,
                    attachment_type, uploaded_by, uploaded_at
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
            """, (
                grn_id, original_filename, file_path, file_type, file_size,
                attachment_type, current_user
            ))

            uploaded_files.append({
                'attachment_id': cursor.lastrowid,
                'file_name': original_filename,
                'file_size': file_size,
                'file_type': file_type,
                'attachment_type': attachment_type
            })

        conn.commit()

        return jsonify({
            'success': True,
            'message': f'{len(uploaded_files)} file(s) uploaded successfully',
            'files': uploaded_files,
            'skipped_files': skipped_files,
        }), 200

    except Exception as e:
        conn.rollback()
        # The rows were rolled back, so drop the files they would have referenced.
        for path in saved_paths:
            try:
                os.remove(path)
            except OSError as cleanup_err:
                print(f"⚠️ Could not remove {path}: {cleanup_err}")
        print(f"❌ Upload error: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()


@grn_bp.route('/view_grn_attachment/<int:attachment_id>', methods=['GET'])
def view_grn_attachment(attachment_id):
    """Serve a GRN attachment inline (``as_attachment=False``) for in-browser viewing.

    This route has no ``@jwt_required``. It authenticates manually so the
    token can arrive as a ``token`` query parameter (presumably so the file
    can be opened from a plain link or iframe, which cannot set headers).
    An ``Authorization: Bearer <token>`` header is the fallback. Only access
    tokens whose ``role`` claim is 'admin' or 'manager' are accepted.

    Returns:
        The file via ``send_file``; 401 for a missing/invalid token; 403 for
        an insufficient role; 404 if the record or file is missing; 500 on
        database errors.
    """
    print(f"\n👁️ VIEW ATTACHMENT REQUEST: ID {attachment_id}")

    # Query parameter first, then the Authorization header.
    token = request.args.get('token')
    if not token:
        auth_header = request.headers.get('Authorization', '').strip()
        # Strip only a leading "Bearer " scheme; a bare token is used as-is.
        if auth_header.startswith('Bearer '):
            token = auth_header[len('Bearer '):].strip()
        else:
            token = auth_header

    if not token:
        print("❌ No token provided")
        return jsonify({
            'error': 'authorization_required',
            'message': 'Missing authentication token'
        }), 401

    # Verify token manually
    try:
        from flask_jwt_extended import decode_token
        decoded = decode_token(token)

        # decode_token() checks signature and expiry but not the token type,
        # whereas @jwt_required() only admits access tokens. Reject refresh
        # tokens here for parity with the other endpoints.
        # NOTE(review): the JWT blocklist (if the app configures one) does not
        # appear to be consulted by decode_token either — confirm.
        if decoded.get('type') != 'access':
            raise ValueError("Only access tokens are accepted")

        user_id = decoded.get('sub')
        user_role = decoded.get('role')

        print(f"✅ Token valid - User: {user_id}, Role: {user_role}")

        if user_role not in ['admin', 'manager']:
            print(f"❌ Unauthorized role: {user_role}")
            return jsonify({
                'error': 'forbidden',
                'message': 'Insufficient permissions'
            }), 403

    except Exception as e:
        print(f"❌ Token verification failed: {str(e)}")
        return jsonify({
            'error': 'invalid_token',
            'message': 'Invalid or expired token'
        }), 401

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute("""
            SELECT * FROM grn_attachments 
            WHERE attachment_id = %s
        """, (attachment_id,))

        attachment = cursor.fetchone()

        if not attachment:
            print(f"❌ Attachment {attachment_id} not found")
            return jsonify({'error': 'Attachment not found'}), 404

        # Stored paths are relative to the process CWD (see the upload
        # endpoint). Making the path absolute ensures the existence check and
        # send_file() look at the same file, even if send_file resolves
        # relative paths against the app root rather than the CWD.
        file_path = os.path.abspath(attachment['file_path'])
        print(f"📄 File path: {file_path}")

        if not os.path.isfile(file_path):
            print(f"❌ File not found on disk: {file_path}")
            return jsonify({'error': 'File not found on server'}), 404

        print(f"✅ Sending file: {attachment['file_name']}")

        return send_file(
            file_path,
            mimetype=attachment['file_type'],
            as_attachment=False,  # Display in browser
            download_name=attachment['file_name']
        )

    except Exception as e:
        print(f"❌ Error: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()


@grn_bp.route('/download_grn_attachment/<int:attachment_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def download_grn_attachment(attachment_id):
    """Send a GRN attachment as a download (``as_attachment=True``).

    Looks up the ``grn_attachments`` row by id and streams the stored file
    under its original ``file_name`` with the recorded ``file_type``.

    Returns:
        The file via ``send_file``; 404 if the record or file is missing;
        500 on database errors.
    """
    from flask import send_file

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute("""
            SELECT * FROM grn_attachments 
            WHERE attachment_id = %s
        """, (attachment_id,))

        attachment = cursor.fetchone()

        if not attachment:
            return jsonify({'error': 'Attachment not found'}), 404

        # Absolute path so the existence check and send_file() agree on
        # which file is meant, whatever base each uses for relative paths.
        file_path = os.path.abspath(attachment['file_path'])

        if not os.path.isfile(file_path):
            return jsonify({'error': 'File not found on disk'}), 404

        return send_file(
            file_path,
            mimetype=attachment['file_type'],
            as_attachment=True,  # Force download
            download_name=attachment['file_name']
        )

    except Exception as e:
        print(f"❌ Error downloading attachment {attachment_id}: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()


@grn_bp.route('/delete_grn_attachment/<int:attachment_id>', methods=['DELETE'])
@jwt_required()
@role_required('admin', 'manager')
def delete_grn_attachment(attachment_id):
    """Delete one GRN attachment: its DB row first, then its file on disk.

    The row is deleted and committed before the file is removed. A failed
    DB delete therefore never leaves a row pointing at a missing file. If
    removing the file fails after the commit, the failure is printed and the
    request still succeeds (the file is merely orphaned).

    Returns:
        200 on success; 404 if the attachment does not exist; 500 on
        database errors (rolled back).
    """
    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute("""
            SELECT * FROM grn_attachments 
            WHERE attachment_id = %s
        """, (attachment_id,))

        attachment = cursor.fetchone()

        if not attachment:
            return jsonify({'error': 'Attachment not found'}), 404

        cursor.execute("""
            DELETE FROM grn_attachments 
            WHERE attachment_id = %s
        """, (attachment_id,))

        conn.commit()

        # Only touch the disk once the row is gone for good.
        file_path = attachment['file_path']
        if file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as file_err:
                print(f"⚠️ Could not remove attachment file {file_path}: {file_err}")

        return jsonify({
            'success': True,
            'message': 'Attachment deleted successfully'
        }), 200

    except Exception as e:
        conn.rollback()
        print(f"❌ Error deleting attachment {attachment_id}: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()
        
def _po_grn_item_to_json(item):
    """Convert one ``grn_items`` row (dict cursor) into a JSON-safe dict.

    Numeric columns become floats (NULL -> 0.0). ``expiration_date`` is
    rendered as ``YYYY-MM-DD`` and ``created_at`` as ``YYYY-MM-DD HH:MM:SS``;
    NULL dates stay None.
    """
    return {
        'grn_item_id':       item['grn_item_id'],
        'grn_id':            item['grn_id'],
        'product_id':        item['product_id'],
        'variation_id':      item['variation_id'],
        'ordered_quantity':  float(item['ordered_quantity']  or 0),
        'received_quantity': float(item['received_quantity'] or 0),
        'rejected_quantity': float(item['rejected_quantity'] or 0),
        'unit_price':        float(item['unit_price']        or 0),
        'batch_price':       float(item['batch_price']       or 0),
        'discount':          float(item['discount']          or 0),
        'discount_type':     item['discount_type'],
        'tax':               float(item['tax']               or 0),
        'tax_type':          item['tax_type'],
        'net_unit_cost':     float(item['net_unit_cost']     or 0),
        'subtotal':          float(item['subtotal']          or 0),
        'quality_check':     item['quality_check'],
        'purchase_unit':     item['purchase_unit'],
        'note':              item['note'],
        'batch_id':          item['batch_id'],
        'expiration_date': (
            item['expiration_date'].strftime('%Y-%m-%d')
            if item['expiration_date'] else None
        ),
        'created_at': (
            item['created_at'].strftime('%Y-%m-%d %H:%M:%S')
            if item['created_at'] else None
        ),
    }


@grn_bp.route('/get_grn_by_po/<int:purchase_order_id>', methods=['GET'])
@jwt_required()
@role_required('admin', 'manager')
def get_grn_by_po(purchase_order_id):
    """
    Get all GRNs (any status, including cancelled) for one purchase order,
    newest first, each with its line items.

    Used by continue_grn.html to calculate:
    - Total already received quantity
    - Remaining quantity per item

    Returns:
    {
        "success": true,
        "grns": [
            {
                "grn_id": 1,
                "grn_code": "GRN-2025-00001",
                "status": "completed",
                "items": [...]
            }
        ],
        "count": 1
    }
    """

    conn = get_db_connection()
    if not conn:
        return jsonify({'error': 'Database connection failed'}), 500

    cursor = conn.cursor(dictionary=True)

    try:
        print(f"\n📋 Fetching GRNs for Purchase Order: {purchase_order_id}")

        # ── Step 1: Get all GRNs for this PO ────────────────────────────────────
        cursor.execute("""
            SELECT
                grn_id, grn_code, purchase_order_id, status,
                grn_date, received_by, invoice_number, invoice_date,
                vehicle_number, driver_name, driver_contact,
                total_items, subtotal, tax, discount, grand_total,
                order_tax, payment_status, payment_type, paid_amount, due_amount,
                note, approved_by, approved_at, created_at, updated_at
            FROM grn
            WHERE purchase_order_id = %s
            ORDER BY grn_id DESC
        """, (purchase_order_id,))

        grns = cursor.fetchall()

        if not grns:
            print(f"ℹ️  No GRNs found for purchase order {purchase_order_id}")
            return jsonify({
                'success': True,
                'grns':    [],
                'count':   0,
            }), 200

        # ── Step 2: Get the items of all GRNs in one query ──────────────────────
        # One IN (...) query instead of one query per GRN. Only "%s"
        # placeholders are formatted into the SQL; the ids themselves are bound.
        grn_ids = [grn['grn_id'] for grn in grns]
        placeholders = ', '.join(['%s'] * len(grn_ids))
        cursor.execute(f"""
            SELECT
                grn_item_id, grn_id, product_id, variation_id,
                ordered_quantity, received_quantity, rejected_quantity,
                unit_price, batch_price,
                discount, discount_type,
                tax, tax_type,
                net_unit_cost, subtotal,
                quality_check,
                purchase_unit,
                note, batch_id,
                expiration_date,
                created_at
            FROM grn_items
            WHERE grn_id IN ({placeholders})
            ORDER BY grn_id, grn_item_id
        """, tuple(grn_ids))

        # Every GRN gets a list, even one without items.
        items_by_grn = {grn_id: [] for grn_id in grn_ids}
        for item in cursor.fetchall():
            items_by_grn[item['grn_id']].append(_po_grn_item_to_json(item))

        # ── Step 3: Attach items and format GRN-level date/numeric fields ───────
        for grn in grns:
            grn['items'] = items_by_grn[grn['grn_id']]

            if grn.get('grn_date'):
                grn['grn_date'] = grn['grn_date'].strftime('%Y-%m-%d')

            if grn.get('invoice_date'):
                grn['invoice_date'] = grn['invoice_date'].strftime('%Y-%m-%d')

            for dt_field in ('approved_at', 'created_at', 'updated_at'):
                if grn.get(dt_field):
                    grn[dt_field] = grn[dt_field].strftime('%Y-%m-%d %H:%M:%S')

            grn['total_items'] = int(grn['total_items'] or 0)
            for money_field in ('subtotal', 'tax', 'discount', 'grand_total',
                                'order_tax', 'paid_amount', 'due_amount'):
                grn[money_field] = float(grn[money_field] or 0)

        print(f"✅ Found {len(grns)} GRN(s) for purchase order {purchase_order_id}")
        for grn in grns:
            print(f"   - GRN {grn['grn_code']}: {grn['status']}, {len(grn['items'])} items")

        return jsonify({
            'success': True,
            'grns':    grns,
            'count':   len(grns),
        }), 200

    except Exception as e:
        print(f"❌ Error in get_grn_by_po: {str(e)}")
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

    finally:
        cursor.close()
        conn.close()
        print("🔒 Database connection closed\n")
