from operator import is_
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
import time
import math

# Seed both random-number generators (Python's `random` module and NumPy's
# global RandomState) with one shared value so that generated instances and
# annealing runs are reproducible from one execution to the next.
_SEED = 42
random.seed(_SEED)
np.random.seed(_SEED)

# Functions provided by the professor
def random_item():
    """Draw one random knapsack item.

    The weight is ``random.random() * 50`` (i.e. in [0, 50) before rounding),
    rounded to two decimals. The value is drawn from a normal distribution
    centred on the weight with standard deviation 10, rounded to two
    decimals, and redrawn until it is non-negative.

    Returns:
        tuple ``(weight, value)`` of Python floats.
    """
    weight = round(random.random() * 50, 2)
    while True:
        candidate = round(np.random.normal(weight, 10, 1)[0], 2)
        if candidate >= 0:
            break
    return (weight, float(candidate))

def random_data(N):
    """Generate a random knapsack instance with ``N`` items.

    The capacity is a random fraction in [0, 0.7) of the items' total
    weight, rounded to an integer and never smaller than 1.

    Returns:
        tuple ``(items, capacity)`` where ``items`` is a list of
        ``random_item()`` results.
    """
    items = []
    for _ in range(N):
        items.append(random_item())
    weights = [item[0] for item in items]
    capacity = max(1, round(random.random() * sum(weights) * 0.7))
    return items, capacity

def random_solution(N):
    """Return a random subset of ``range(N)``.

    Each index is included independently with probability 1/2, decided by
    ``random.randint(0, 1) == 0`` in increasing index order.
    """
    return {idx for idx in range(N) if random.randint(0, 1) == 0}

def greedy_solution(items, capacity):
    """Build a feasible solution by packing items in decreasing value density.

    Items are visited from the highest to the lowest value/weight ratio and
    added whenever they still fit in the remaining capacity. Items of zero
    weight consume no capacity, so they are ranked first instead of causing
    a division by zero. Ties keep their original order (``sorted`` is stable).

    Args:
        items: sequence of ``(weight, value)`` pairs.
        capacity: maximum total weight.

    Returns:
        set of indices of the selected items.
    """
    def density(entry):
        weight, value = entry[1][0], entry[1][1]
        if weight == 0:
            return float("inf")
        return value / weight

    by_density = sorted(enumerate(items), key=density, reverse=True)
    remaining_capacity = capacity
    sol = set()
    for index, item in by_density:
        if item[0] <= remaining_capacity:
            sol.add(index)
            remaining_capacity -= item[0]  # Subtract weight, not value
    return sol

# def tweak(sol, N):
#     index = random.randint(0, N-1)  # Choose a random item
#     new_sol = set(sol)
#     if index in new_sol:
#         new_sol.remove(index)
#     else:
#         new_sol.add(index)
#     return new_sol

def tweak(sol, N):
    """Return a random neighbour of ``sol`` without modifying ``sol``.

    A copy of ``sol`` loses between 0 and 3 randomly chosen indices (fewer if
    it becomes empty first). Then one index drawn uniformly from ``range(N)``
    is added, and it may already be present.
    """
    neighbour = set(sol)
    for _ in range(random.randint(0, 3)):  # Choose between 0 and 3 items to remove
        if not neighbour:
            break
        neighbour.remove(random.choice(list(neighbour)))
    neighbour.add(random.randint(0, N - 1))
    return neighbour


def score(sol, items, capacity):
    """Return the objective value of ``sol``.

    This is the sum of the selected items' values (element 1). It is 0 when
    the selected weight (element 0) exceeds ``capacity``, which marks the
    solution as invalid.
    """
    weight = sum(items[idx][0] for idx in sol)
    return 0 if weight > capacity else sum(items[idx][1] for idx in sol)

# Additional helper functions
def get_total_weight(sol, items):
    """Sum the weights (element 0) of the items whose indices are in ``sol``."""
    weights = (items[idx][0] for idx in sol)
    return sum(weights)

def is_valid_solution(sol, items, capacity):
    """Return True when the items indexed by ``sol`` fit within ``capacity``."""
    total = sum(items[idx][0] for idx in sol)
    return total <= capacity

# Figure for the live animation, created once at import time: the knapsack
# grid goes on the left (ax1) and the convergence plot on the right (ax2).
fig = plt.figure(figsize=(16, 9))
ax1, ax2 = (fig.add_subplot(position) for position in (121, 122))
# Interactive mode (plt.ion()) is left disabled; show_knapsack refreshes the
# figure through plt.pause().

# Convergence-tracking state. simulated_annealing() rebinds several of these
# lists via `global`, and show_knapsack() reads them to redraw ax2.
iterations_list = []       # iteration number of each recorded solution
solution_scores_list = []  # value of each recorded solution
solution_valid_list = []   # feasibility flag of each recorded solution
best_scores_list = []      # best value found so far, per recorded solution
accepted_worse_list = []   # iterations at which a worse move was accepted
restart_markers = []       # not referenced elsewhere in this file
temperature_list = []      # not referenced elsewhere in this file

# Function to update the knapsack visualization
def _draw_capacity_bar(total_weight, capacity):
    """Draw the weight-usage bar and its caption at the bottom of ``ax1``."""
    bar_height = 0.3
    bar_width = 8

    # Background bar (total capacity)
    ax1.add_patch(Rectangle((1, 0.3), bar_width, bar_height,
                            fill=True, color='lightgray'))

    # Keep the unclamped ratio so that an overweight selection is drawn red.
    # Only the bar's drawn width is clamped to the available space.
    if capacity > 0:
        usage_ratio = total_weight / capacity
    else:
        # Avoid dividing by a zero or negative capacity.
        usage_ratio = float('inf') if total_weight > capacity else 0.0
    usage_color = 'green' if usage_ratio <= 1.0 else 'red'
    ax1.add_patch(Rectangle((1, 0.3), bar_width * min(1.0, usage_ratio),
                            bar_height, fill=True, color=usage_color))

    ax1.text(1 + bar_width / 2, 0.15,
             f'Weight: {total_weight:.2f} / {capacity:.2f}',
             horizontalalignment='center')


def _draw_item_grid(sol, n_items, valid, max_drawn=10000):
    """Draw one square per item on ``ax1``.

    Selected items are blue, or red when the selection is invalid, and
    unselected items are light gray. At most ``max_drawn`` squares are drawn.
    """
    if n_items <= 0:
        return  # Nothing to draw; also avoids a zero column count below.

    grid_cols = min(100, int(math.ceil(math.sqrt(n_items))))
    grid_rows = int(math.ceil(n_items / grid_cols))

    # Fit the grid into an 8x8 area starting at (1, 1); squares touch.
    display_width = 8
    display_height = 8
    square_size = min(display_width / grid_cols, display_height / grid_rows)
    start_x = 1
    start_y = 1

    selected = set(sol)  # O(1) membership tests even when sol is a list
    in_color = 'blue' if valid else 'red'
    for i in range(min(n_items, max_drawn)):
        row, col = divmod(i, grid_cols)
        x = start_x + col * square_size
        y = start_y + row * square_size
        color = in_color if i in selected else 'lightgray'
        ax1.add_patch(Rectangle((x, y), square_size, square_size,
                                fill=True, color=color, linewidth=0))


def _draw_convergence_plot():
    """Rebuild ``ax2`` from the module-level tracking lists."""
    ax2.clear()
    ax2.set_title("Simulated Annealing Progress")
    ax2.set_xlabel("Iterations")
    ax2.set_ylabel("Value")

    if not iterations_list:
        return

    valid_iterations, valid_scores = [], []
    invalid_iterations, invalid_scores = [], []
    for it, sc, ok in zip(iterations_list, solution_scores_list, solution_valid_list):
        if ok:
            valid_iterations.append(it)
            valid_scores.append(sc)
        else:
            invalid_iterations.append(it)
            invalid_scores.append(sc)

    # Valid solutions: blue circles
    if valid_iterations:
        ax2.scatter(valid_iterations, valid_scores, s=20, color='blue', alpha=0.5,
                    label='Valid Solutions')

    # Invalid solutions: larger red squares
    if invalid_iterations:
        ax2.scatter(invalid_iterations, invalid_scores, s=60, color='red', alpha=0.5,
                    marker='s', label='Invalid Solutions')

    # Accepted worse moves: yellow triangles. Build the iteration -> score map
    # once (the first occurrence wins, like list.index). Searching the whole
    # list for every point was quadratic in the length of the run.
    if accepted_worse_list:
        score_at = {}
        for it, sc in zip(iterations_list, solution_scores_list):
            score_at.setdefault(it, sc)
        worse_scores = [score_at[it] for it in accepted_worse_list]
        ax2.scatter(accepted_worse_list, worse_scores, s=40, color='yellow', alpha=0.8,
                    marker='^', edgecolor='black', label='Accepted Worse')

    ax2.plot(iterations_list, best_scores_list, 'g-', linewidth=2, label='Best Solution')
    ax2.legend(loc='upper left')


def show_knapsack(sol, items, capacity, current_score, best_score, temp, iteration, valid, acceptance_rate=None):
    """Redraw both panels of the animation figure.

    The left panel (module-level ``ax1``) shows a capacity bar, a grid with
    one square per item, an information box and a legend. The right panel
    (``ax2``) is rebuilt from the module-level lists ``iterations_list``,
    ``solution_scores_list``, ``solution_valid_list``, ``best_scores_list``
    and ``accepted_worse_list``.

    Args:
        sol: collection of selected item indices.
        items: sequence of ``(weight, value)`` pairs.
        capacity: knapsack weight limit.
        current_score: value shown as "Current Value".
        best_score: value shown as "Best Value".
        temp: temperature shown in the information box.
        iteration: iteration number shown in the information box.
        valid: whether ``sol`` respects the capacity; selected squares are
            drawn red when False.
        acceptance_rate: percentage shown in the box, or None to omit it.
    """
    ax1.clear()
    ax1.set_xlim([0, 10])
    ax1.set_ylim([0, 10])

    _draw_capacity_bar(get_total_weight(sol, items), capacity)
    _draw_item_grid(sol, len(items), valid)

    ax1.set_title("Knapsack Problem Visualization")

    acceptance_text = f"Acceptance Rate: {acceptance_rate:.2f}%" if acceptance_rate is not None else ""
    info_text = (
        f"Iteration: {iteration}\n"
        f"Status: {'Valid' if valid else 'INVALID'}\n"
        f"Temperature: {temp:.6f}\n"
        f"{acceptance_text}\n"
        f"Current Value: {current_score:.2f}\n"
        f"Best Value: {best_score:.2f}\n"
        f"Items in Knapsack: {len(sol)} / {len(items)}"
    )
    ax1.text(0.02, 0.98, info_text, transform=ax1.transAxes,
             fontsize=10, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    # Legend proxies for the grid colours
    in_valid_square = Rectangle((0, 0), 1, 1, color='blue', alpha=0.7)
    in_invalid_square = Rectangle((0, 0), 1, 1, color='red', alpha=0.7)
    out_square = Rectangle((0, 0), 1, 1, color='lightgray', alpha=0.7)
    ax1.legend([in_valid_square, in_invalid_square, out_square],
               ['In Knapsack (Valid)', 'In Knapsack (Invalid)', 'Not in Knapsack'],
               loc='upper right')

    _draw_convergence_plot()

    # Draw and update the figure
    fig.tight_layout()
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.pause(0.000001)

# Simulated annealing for knapsack problem
def simulated_annealing(items, capacity, use_greedy=True, *, initial_temp=50.0,
                        alpha=0.99, final_temp=0.001, trials_per_temp=1000):
    """Maximise knapsack value with simulated annealing, animating progress.

    The search starts from ``greedy_solution`` (``use_greedy=True``) or from
    the empty knapsack. Each trial proposes a feasible neighbour via
    ``tweak`` and accepts it when it is no worse. A worse neighbour is
    accepted with probability ``exp(delta / temp)``. After every
    ``trials_per_temp`` trials the temperature is multiplied by ``alpha``,
    until it is no longer above ``final_temp``.

    Side effects: the module-level tracking lists are reset and filled, the
    figure is redrawn through ``show_knapsack``, progress is printed, and the
    function ends with a blocking ``plt.show``.

    Args:
        items: sequence of ``(weight, value)`` pairs.
        capacity: knapsack weight limit.
        use_greedy: start from the greedy packing instead of the empty set.
        initial_temp: starting temperature.
        alpha: geometric cooling factor; must satisfy ``0 < alpha < 1``.
        final_temp: stopping temperature; must be positive.
        trials_per_temp: number of proposals evaluated per temperature.

    Returns:
        tuple ``(best_sol, best_value, iterations)`` where ``best_sol`` is a
        set of item indices.

    Raises:
        ValueError: if ``alpha`` is outside (0, 1), since the schedule would
            never terminate, or if ``final_temp`` is not positive, since the
            temperature could reach 0 and divide by zero.
    """
    global iterations_list, solution_scores_list, solution_valid_list, best_scores_list, accepted_worse_list

    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")
    if final_temp <= 0:
        raise ValueError(f"final_temp must be positive, got {final_temp}")

    # Reset tracking lists (rebinds the module-level names show_knapsack reads)
    iterations_list = []
    solution_scores_list = []
    solution_valid_list = []
    best_scores_list = []
    accepted_worse_list = []

    N = len(items)

    # Initial solution: the greedy packing, or the empty knapsack. A set keeps
    # the type consistent with greedy_solution() and tweak().
    sol = greedy_solution(items, capacity) if use_greedy else set()

    value = score(sol, items, capacity)
    valid = is_valid_solution(sol, items, capacity)

    best_sol = set(sol)
    best_value = value

    temp = initial_temp
    iteration = 0
    acceptance_rate = 0  # defined even if no trial ever runs

    # Record and show the initial solution
    iterations_list.append(iteration)
    solution_scores_list.append(value)
    solution_valid_list.append(valid)
    best_scores_list.append(best_value)
    show_knapsack(sol, items, capacity, value, best_value, temp, iteration, valid)

    # tweak() always adds an item, so from the empty knapsack the retry loop
    # below finds a feasible neighbour only if some single item fits. Without
    # this check it would spin forever (and tweak() fails outright for N == 0).
    # Weights are assumed non-negative, as random_item() produces: if no single
    # item fits, then no non-empty subset fits either.
    can_search = any(item[0] <= capacity for item in items)

    # Main simulated annealing loop
    while can_search and temp > final_temp:
        accepted_worse = 0
        total_worse = 0

        for _ in range(trials_per_temp):
            iteration += 1

            # Propose a neighbour, retrying until it respects the capacity
            new_sol = tweak(sol, N)
            while not is_valid_solution(new_sol, items, capacity):
                new_sol = tweak(sol, N)
            new_value = score(new_sol, items, capacity)

            delta = new_value - value  # Positive delta means better solution

            if delta >= 0:  # Better or equal solution: always accept
                accept = True
            else:  # Worse solution: accept with probability exp(delta / temp)
                total_worse += 1
                p = math.exp(delta / temp)
                accept = random.random() < p
                if accept:
                    accepted_worse += 1
                    # Store the iteration number, not a list index
                    accepted_worse_list.append(iteration)

            if accept:
                sol = new_sol
                value = new_value
                # Every proposal was checked for feasibility above, so the
                # accepted solution is valid. (The flag was previously never
                # refreshed and just kept the initial solution's status.)
                valid = True
                if value > best_value:
                    best_sol = set(sol)
                    best_value = value

                # Tracking records accepted solutions only
                iterations_list.append(iteration)
                solution_scores_list.append(value)
                solution_valid_list.append(valid)
                best_scores_list.append(best_value)

            acceptance_rate = 0 if total_worse == 0 else (accepted_worse / total_worse * 100)

            # Redraw for roughly 1 in 10000 accepted moves. NOTE: this draw
            # consumes the same `random` stream as the search itself.
            if accept and random.randint(1, 10000) == 1:
                show_knapsack(sol, items, capacity, value, best_value, temp, iteration, valid, acceptance_rate)

            # Print progress periodically
            if iteration % 50 == 0:
                print(f"Iteration {iteration}: temp={temp:.6f}, value={value:.2f}, "
                      f"best={best_value:.2f}, accept_rate={acceptance_rate:.2f}%")

        # Cool down temperature
        temp *= alpha

        print(f"Cooled to temp={temp:.6f}, current value={value:.2f}, "
              f"best value={best_value:.2f}, acceptance rate={acceptance_rate:.2f}%")

    # Report and show the best solution found
    best_valid = is_valid_solution(best_sol, items, capacity)
    print("\nSimulated Annealing complete!")
    print(f"Best solution found: value={best_value:.2f}, valid={best_valid}")
    print(f"Items in knapsack: {len(best_sol)}/{N}")

    show_knapsack(best_sol, items, capacity, best_value, best_value, temp, iteration,
                  best_valid, 0)

    plt.ioff()  # Turn off interactive mode
    plt.show(block=True)  # Show the final plot

    return best_sol, best_value, iteration

# Main function to run the visualization
def main():
    """Generate a random 1000-item instance, print it, and run the annealer.

    The search starts from the empty knapsack (``use_greedy=False``).
    """
    num_items = 1000  # Number of items
    items, capacity = random_data(num_items)

    print(f"Generated {num_items} items with total capacity: {capacity:.2f}")
    print("Items (Weight, Value):")
    for idx, item in enumerate(items):
        weight, value = item[0], item[1]
        print(f"Item {idx}: Weight={weight:.2f}, Value={value:.2f}")

    # Run simulated annealing with visualization
    simulated_annealing(items, capacity, use_greedy=False)


if __name__ == "__main__":
    # Script entry point
    main()