from operator import is_
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
import time
import math

# Seed both generators for reproducible runs: this file draws item weights
# from `random` and item values from `np.random`, so each needs its own seed.
random.seed(42)
np.random.seed(42)

# Functions provided by the professor
def random_item():
    """Draw one random knapsack item as a ``(weight, value)`` tuple.

    The weight is ``random.random() * 50`` rounded to two decimals. The value
    is sampled from a normal distribution centred on the weight (standard
    deviation 10), rounded to two decimals, and redrawn for as long as the
    draw is negative. The value is returned as a plain ``float``.
    """
    weight = round(random.random() * 50, 2)
    while True:
        value = round(np.random.normal(weight, 10, 1)[0], 2)
        # Reject negative values and sample again.
        if not value < 0:
            break
    return (weight, float(value))

def random_data(N):
    """Generate a random knapsack instance with ``N`` items.

    Returns:
        A tuple ``(items, capacity)`` where ``items`` is a list of
        ``(weight, value)`` pairs from ``random_item()`` and ``capacity`` is a
        random fraction (at most 70%) of the total item weight, rounded to an
        integer and never smaller than 1.
    """
    items = []
    for _ in range(N):
        items.append(random_item())
    total_weight = sum(entry[0] for entry in items)
    capacity = max(1, round(random.random() * total_weight * 0.7))
    return (items, capacity)

def random_solution(N):
    """Return a random subset of the indices ``0..N-1``.

    Each index is included independently when a fair coin flip
    (``random.randint(0, 1)``) comes up 0.
    """
    return {idx for idx in range(N) if random.randint(0, 1) == 0}

def greedy_solution(items, capacity):
    """Build a knapsack solution greedily by value density.

    Items are visited in order of decreasing ``value / weight`` and each one
    is added if it still fits in the remaining capacity.

    Args:
        items: Sequence of ``(weight, value)`` pairs.
        capacity: Maximum total weight allowed.

    Returns:
        A set with the indices of the chosen items.
    """
    def density(entry):
        weight, value = entry[1][0], entry[1][1]
        # A weightless item uses no capacity, so rank it first instead of
        # dividing by zero (rounded random weights can be exactly 0.0).
        if weight == 0:
            return float("inf")
        return value / weight

    by_density = sorted(enumerate(items), key=density, reverse=True)
    remaining_capacity = capacity
    sol = set()
    for (index, item) in by_density:
        if item[0] <= remaining_capacity:
            sol.add(index)
            remaining_capacity -= item[0]  # Subtract weight, not value
    return sol

def tweak(sol, N):
    """Return a copy of ``sol`` with one randomly chosen item toggled.

    An index in ``0..N-1`` is picked uniformly; it is removed from the copy
    if present and added otherwise. ``sol`` itself is left untouched.
    """
    flipped = random.randint(0, N - 1)
    # Symmetric difference with a singleton toggles that one element.
    return set(sol) ^ {flipped}

def score(sol, items, capacity):
    """Return the total value of ``sol``, or 0 when it exceeds ``capacity``.

    ``sol`` is an iterable of indices into ``items``, whose entries are
    ``(weight, value)`` pairs.
    """
    load = sum(items[idx][0] for idx in sol)
    overweight = load > capacity
    # Overweight selections are invalid and score nothing.
    return 0 if overweight else sum(items[idx][1] for idx in sol)

# Additional helper functions
def get_total_weight(sol, items):
    """Return the summed weight of the items whose indices are in ``sol``."""
    chosen_weights = (items[idx][0] for idx in sol)
    return sum(chosen_weights)

def is_valid_solution(sol, items, capacity):
    """Return True when the total weight of ``sol`` does not exceed ``capacity``."""
    load = get_total_weight(sol, items)
    return load <= capacity

# Initialize the figure for animation.
# NOTE: these statements run at import time, so importing this module creates
# a figure and switches matplotlib to interactive mode.
fig = plt.figure(figsize=(16, 9))
ax1 = fig.add_subplot(121)  # Left subplot for knapsack
ax2 = fig.add_subplot(122)  # Right subplot for convergence plot
plt.ion()  # Turn on interactive mode

# Lists to store convergence data. They are rebound (via `global`) by
# hill_climbing_with_restarts() and read by show_knapsack() when it redraws
# the convergence plot.
iterations_list = []       # x-coordinates (total iteration count) of recorded points
restart_markers = []       # iteration counts at which a restart began
best_scores_list = []      # best value so far at each recorded point
solution_scores_list = []  # value of the solution recorded at each point

# Function to update the knapsack visualization
def _draw_capacity_bar(total_weight, capacity):
    """Draw the capacity usage bar and its weight label along the bottom of ``ax1``."""
    capacity_bar_height = 0.3
    capacity_width = 8

    # Background bar (total capacity)
    capacity_bg = Rectangle((1, 0.3), capacity_width, capacity_bar_height,
                            fill=True, color='lightgray')
    ax1.add_patch(capacity_bg)

    # Choose the color from the *unclamped* ratio: clamping first would make
    # the overweight (red) case unreachable.
    if capacity > 0:
        usage_ratio = total_weight / capacity
    else:
        # No capacity at all: any positive weight is over the limit.
        usage_ratio = float('inf') if total_weight > 0 else 0.0
    usage_color = 'green' if usage_ratio <= 1.0 else 'red'

    # Filled bar (current usage), clamped so it never overflows the background
    usage_width = capacity_width * min(1.0, usage_ratio)
    usage_bar = Rectangle((1, 0.3), usage_width, capacity_bar_height,
                          fill=True, color=usage_color)
    ax1.add_patch(usage_bar)

    # Add capacity text
    ax1.text(1 + capacity_width/2, 0.15, f'Weight: {total_weight:.2f} / {capacity:.2f}',
             horizontalalignment='center')


def _draw_item_grid(sol, items):
    """Draw one square per item on ``ax1``: blue if selected, light gray otherwise."""
    N = len(items)

    # At most 100 columns; at least one row and column so an empty item list
    # does not cause a division by zero.
    grid_cols = max(1, min(100, int(math.ceil(math.sqrt(N)))))
    grid_rows = max(1, int(math.ceil(N / grid_cols)))

    # Define display area dimensions
    display_width = 8
    display_height = 8

    # Use the smaller dimension so the squares fit both horizontally and vertically
    square_size = min(display_width / grid_cols, display_height / grid_rows)

    # The grid is anchored at the bottom-left corner of the display area
    start_x = 1
    start_y = 1

    for i in range(min(N, 10000)):  # Limit to 10,000 items max
        row, col = divmod(i, grid_cols)

        # Squares touch with no padding
        x = start_x + col * square_size
        y = start_y + row * square_size

        color = 'blue' if i in sol else 'lightgray'
        square = Rectangle((x, y), square_size, square_size,
                           fill=True, color=color, linewidth=0)
        ax1.add_patch(square)


def _draw_convergence_plot():
    """Redraw ``ax2`` from the module-level convergence tracking lists."""
    ax2.clear()
    ax2.set_title("Optimization Progress")
    ax2.set_xlabel("Iterations")
    ax2.set_ylabel("Value")

    if iterations_list:
        # Recorded solution scores
        ax2.scatter(iterations_list, solution_scores_list, s=20, color='blue', alpha=0.5,
                    label='Evaluated Solutions')

        # Best score line
        ax2.plot(iterations_list, best_scores_list, 'g-', linewidth=2, label='Best Valid Solution')

        # Restart markers
        for marker in restart_markers:
            ax2.axvline(x=marker, color='r', linestyle='--', alpha=0.5)

        ax2.legend(loc='lower right')


def show_knapsack(sol, items, capacity, current_score, best_score, restart_num, iteration, valid):
    """Redraw the knapsack view (``ax1``) and the convergence plot (``ax2``).

    Args:
        sol: Collection of indices of the items currently in the knapsack.
        items: Sequence of ``(weight, value)`` pairs.
        capacity: Maximum total weight of the knapsack.
        current_score: Value of ``sol``, shown in the info box.
        best_score: Best value found so far, shown in the info box.
        restart_num: Number of the current restart, shown in the info box.
        iteration: Iteration count within the current restart.
        valid: Truthy when ``sol`` respects the capacity; selects the status text.

    The function draws on the module-level ``fig``/``ax1``/``ax2`` and reads
    the module-level tracking lists; it returns nothing.
    """
    # Clear and set up the knapsack visualization
    ax1.clear()
    ax1.set_xlim([0, 10])
    ax1.set_ylim([0, 10])

    _draw_capacity_bar(get_total_weight(sol, items), capacity)
    _draw_item_grid(sol, items)

    ax1.set_title("Knapsack Problem Visualization")

    # Information box in the top-left corner of the axes
    info_text = (
        f"Restart: {restart_num}, Iteration: {iteration}\n"
        f"Status: {'Valid' if valid else 'Invalid (Exceeds Capacity)'}\n"
        f"Current Value: {current_score:.2f}\n"
        f"Best Value: {best_score:.2f}\n"
        f"Items in Knapsack: {len(sol)} / {len(items)}"
    )
    ax1.text(0.02, 0.98, info_text, transform=ax1.transAxes,
             fontsize=10, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    # Legend for the grid colors (proxy patches, not added to the axes)
    in_square = Rectangle((0, 0), 1, 1, color='blue', alpha=0.7)
    out_square = Rectangle((0, 0), 1, 1, color='lightgray', alpha=0.7)
    ax1.legend([in_square, out_square], ['In Knapsack', 'Not in Knapsack'],
               loc='upper right')

    _draw_convergence_plot()

    # Draw and update the figure
    fig.tight_layout()
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.pause(0.000001)

# Hill climbing with random restarts
def hill_climbing_with_restarts(items, capacity, max_failures, use_greedy=True, max_restarts=None):
    """Run hill climbing with restarts on a knapsack instance, visualising progress.

    Every restart begins from the greedy solution (``use_greedy=True``) or
    from the empty knapsack, then repeatedly toggles one random item and keeps
    the change only if it strictly improves the score. A restart ends after
    ``max_failures`` consecutive non-improving or invalid tweaks.

    Args:
        items: Sequence of ``(weight, value)`` pairs.
        capacity: Maximum total weight of the knapsack.
        max_failures: Consecutive failed tweaks allowed before a restart ends.
        use_greedy: Start each restart from ``greedy_solution`` instead of
            the empty knapsack.
        max_restarts: Number of restarts to perform. ``None`` (the default)
            keeps restarting forever, so the function only ends when the
            process is interrupted.

    Returns:
        ``(best_sol, best_value)`` once ``max_restarts`` restarts have
        completed; ``best_sol`` is ``None`` if no valid solution was recorded.
        An empty ``items`` returns ``(set(), 0)`` immediately.

    Side effects:
        Rebinds the module-level convergence tracking lists, prints progress
        to stdout and redraws the figure through ``show_knapsack``.
    """
    global iterations_list, solution_scores_list, restart_markers, best_scores_list

    # Reset tracking lists
    iterations_list = []
    solution_scores_list = []
    restart_markers = [0]  # Start of first restart
    best_scores_list = []

    N = len(items)
    if N == 0:
        # tweak() cannot pick an index from an empty item list.
        return set(), 0

    best_sol = None
    best_value = 0
    total_iterations = 0
    restart = 0
    while max_restarts is None or restart < max_restarts:
        restart += 1
        print(f"Starting restart {restart}")

        if use_greedy:
            sol = greedy_solution(items, capacity)
        else:
            # Keep the solution a set, like every solution tweak() produces.
            sol = set()

        value = score(sol, items, capacity)
        valid = is_valid_solution(sol, items, capacity)

        # Mark where this restart begins on the convergence plot
        if restart > 1:
            restart_markers.append(total_iterations)

        # Add initial point to tracking
        iterations_list.append(total_iterations)
        solution_scores_list.append(value if valid else 0)

        # Update best if valid
        if valid and (best_sol is None or value > best_value):
            best_sol = set(sol)
            best_value = value

        best_scores_list.append(best_value)

        # Show initial state
        show_knapsack(sol, items, capacity, value, best_value, restart, 0, valid)

        failures = 0
        local_iterations = 0

        # Main optimization loop for this restart
        while failures < max_failures:
            local_iterations += 1
            total_iterations += 1

            # Generate a neighbour; each invalid neighbour counts as a failure
            # and is redrawn until one is valid or the failure budget runs out.
            new_sol = tweak(sol, N)
            while not is_valid_solution(new_sol, items, capacity) and failures < max_failures:
                failures += 1
                new_sol = tweak(sol, N)
            new_value = score(new_sol, items, capacity)

            # Accept only strict improvements
            if new_value > value:
                failures = 0
                sol = new_sol
                value = new_value

                # Update best if needed
                if value > best_value:
                    best_sol = sol.copy()
                    best_value = value

                # Record the improvement and redraw
                iterations_list.append(total_iterations)
                solution_scores_list.append(new_value)
                best_scores_list.append(best_value)
                # Pass a real validity flag; passing the solution set itself
                # would report an empty knapsack as invalid.
                show_knapsack(sol, items, capacity, value, best_value, restart,
                              local_iterations, is_valid_solution(sol, items, capacity))
            else:
                failures += 1

            # Print progress for every iteration
            print(f"Restart {restart}, Iteration {local_iterations}: "
                    f"Items: {len(sol)}/{N}, Value: {value:.2f}, "
                    f"Best overall: {best_value:.2f}")

        print(f"Completed restart {restart}: best score={value:.2f}, "
              f"iterations={local_iterations}")

        # Show final state for this restart, with validity of the final
        # solution rather than of the starting one.
        show_knapsack(sol, items, capacity, value, best_value, restart, local_iterations,
                      is_valid_solution(sol, items, capacity))

    return best_sol, best_value
    

# Main function to run the visualization
def main():
    """Generate a random knapsack instance, print it, and run the visualised search."""
    # Size of the generated problem instance
    N = 1000
    instance = random_data(N)
    items, capacity = instance

    # Dump the instance to stdout before optimizing
    print(f"Generated {N} items with total capacity: {capacity:.2f}")
    print("Items (Weight, Value):")
    for i, item in enumerate(items):
        print(f"Item {i}: Weight={item[0]:.2f}, Value={item[1]:.2f}")

    # Hand over to the hill climber, which also drives the visualization
    hill_climbing_with_restarts(items, capacity, max_failures=1000, use_greedy=True)

# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()