import random
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes  # For proper inset axes
import time
import math  # Required for math.exp() in simulated annealing

def g1(L, d, w):
    """Return constraint g1; a design is feasible for it when the value is <= 0.

    Together with g2..g4 and ``score`` this matches the classic
    tension/compression spring design benchmark. It assumes the usual mapping:
    L = number of active coils, d = mean coil diameter, w = wire diameter.
    In that formulation g1 = 1 - D^3 N / (71785 d^4).

    NOTE: the constant is 71785, not 7178. The shorter value makes the
    constraint roughly ten times looser and moves the optimum.
    """
    return 1 - d**3 * L / (71785 * w**4)

def g2(L, d, w):
    """Return constraint g2; a design is feasible for it when the value is <= 0.

    Benchmark form (D = d here, wire diameter = w):
        (4 D^2 - d D) / (12566 (D d^3 - d^4)) + 1 / (5108 d^2) - 1

    The factor 12566 multiplies the whole difference ``d*w**3 - w**4``.
    Previously only the first term was scaled, which under-estimated the
    constraint value by ~15% near the optimum.

    Raises:
        ZeroDivisionError: when d == w, because the denominator vanishes.
    """
    return (4 * d**2 - w * d) / (12566 * (d * w**3 - w**4)) + 1 / (5108 * w**2) - 1

def g3(L, d, w):
    """Return constraint g3 = 1 - 140.45 w / (d^2 L); feasible when <= 0."""
    scaled_wire = 140.45 * w
    coil_term = L * d**2
    return 1 - scaled_wire / coil_term

def g4(L, d, w):
    """Return constraint g4 = (w + d) / 1.5 - 1; feasible when <= 0. L is unused."""
    combined = d + w
    return combined / 1.5 - 1

def satisfies_constraints(L,d,w):
    """Return True when every constraint g1..g4 evaluates to <= 0.

    Constraints are checked in order g1, g2, g3, g4. Evaluation stops at the
    first violated one.
    """
    return all(constraint(L, d, w) <= 0 for constraint in (g1, g2, g3, g4))

def score(L,d,w):
    """Return the objective (2 + L) * d * w**2; the optimiser minimises it."""
    coil_factor = 2 + L
    return coil_factor * d * w**2

def _jitter_within(value, step, low, high):
    """Return value + U(-1, 1) * step, resampled until it lies in [low, high]."""
    while True:
        candidate = value + random.uniform(-1, 1) * step
        if low <= candidate <= high:
            return candidate


def tweak(L,d,w):
    """Return a random neighbour (L, d, w) of the given design.

    Each variable moves by at most its step size: 0.1 for L, 0.01 for d and w.
    Values falling outside the search box are resampled; they are not clipped.
    Variables are perturbed in the order w, d, L.
    """
    new_w = _jitter_within(w, 0.01, 0.05, 2)
    new_d = _jitter_within(d, 0.01, 0.25, 1.3)
    new_L = _jitter_within(L, 0.1, 2, 15)
    return (new_L, new_d, new_w)

def random_solution():
    """Return a uniformly random (L, d, w) inside the search box.

    Bounds are L in [2, 15], d in [0.25, 1.3] and w in [0.05, 2]. The result
    may still violate g1..g4.
    """
    bounds = ((2, 15), (0.25, 1.3), (0.05, 2))
    return tuple(random.uniform(low, high) for low, high in bounds)
        
# Figure shared by the whole run: a 3-D spring view on the left and the
# convergence plot on the right. Both are redrawn by show_spring_with_sa_info.
fig = plt.figure(figsize=(16, 9))
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
ax2 = fig.add_subplot(1, 2, 2)
ax1.set_xlim([-2, 2])
ax1.set_ylim([-2, 2])
ax1.set_zlim([0, 15])
plt.ion()  # interactive mode so the figure updates while the optimiser runs

# Convergence history read by the plotting code and reset by simulated_annealing.
iter_count = 0
iterations_list, accepted_scores_list, best_scores_so_far = [], [], []

# Dedicated RNG for render throttling. Drawing decisions therefore never consume
# values from the global `random` stream that drives the optimiser.
_render_rng = random.Random()


def show_spring_with_sa_info(L, d, w, current_score, best_score, generation, temp,
                             acceptance_rate, *, force=False, render_one_in=1000):
    """Redraw the spring view and the convergence plot (throttled).

    Rendering is slow, so unless ``force`` is set only about one call in
    ``render_one_in`` actually draws.

    Args:
        L, d, w: design drawn as a helix with L turns, radius d and line
            width proportional to w.
        current_score: objective value of the design being shown.
        best_score: best objective value seen so far.
        generation: index of the current temperature step.
        temp: current annealing temperature.
        acceptance_rate: percentage of worse moves accepted this generation.
        force: always draw; use this for states that must be visible.
        render_one_in: when not forced, draw with probability 1/render_one_in.
    """
    if not force and _render_rng.randint(1, render_one_in) > 1:
        return

    # Left panel: re-apply the fixed limits after clearing.
    ax1.clear()
    ax1.set_xlim([-2, 2])
    ax1.set_ylim([-2, 2])
    ax1.set_zlim([0, 15])

    # Helix with L turns (z rises by 1 per turn), radius d, thickness ~ w.
    theta = np.linspace(0, L * 2 * np.pi, 1000)
    z = theta / (2 * np.pi)
    x = d * np.sin(theta)
    y = d * np.cos(theta)
    ax1.plot(x, y, z, color='b', lw=w * 30)

    g1_val = g1(L, d, w)
    g2_val = g2(L, d, w)
    g3_val = g3(L, d, w)
    g4_val = g4(L, d, w)

    ax1.set_title(f"Spring Optimization (L={L:.4f}, d={d:.4f}, w={w:.4f})")

    constraint_text = (
        f"Generation: {generation}\n"
        f"Temperature: {temp:.6f}\n"
        f"Current Score: {current_score:.8f}\n"
        f"Best Score: {best_score:.8f}\n"
        f"Worse Accepted: {acceptance_rate:.2f}%\n"
        f"g1: {g1_val:.4f} {'✓' if g1_val <= 0 else '✗'}\n"
        f"g2: {g2_val:.4f} {'✓' if g2_val <= 0 else '✗'}\n"
        f"g3: {g3_val:.4f} {'✓' if g3_val <= 0 else '✗'}\n"
        f"g4: {g4_val:.4f} {'✓' if g4_val <= 0 else '✗'}"
    )
    ax1.text2D(0.02, 0.95, constraint_text, transform=ax1.transAxes,
               fontsize=10, verticalalignment='top',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    # Right panel: accepted scores as points, best-so-far as a line.
    ax2.clear()
    ax2.set_title("Simulated Annealing Progress")
    ax2.set_xlabel("Iterations")
    ax2.set_ylabel("Score")

    if iterations_list and len(iterations_list) == len(accepted_scores_list):
        marker_size = 10 if len(iterations_list) < 1000 else 2
        ax2.scatter(iterations_list, accepted_scores_list, color='blue', s=marker_size,
                    alpha=1, label='Accepted Solutions')
        if best_scores_so_far and len(iterations_list) == len(best_scores_so_far):
            ax2.plot(iterations_list, best_scores_so_far, 'g-', linewidth=2, label='Best Score')
        ax2.legend(loc='upper right')

    fig.tight_layout()
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.pause(0.000001)


def _record_accepted(value, best_score):
    """Append one accepted move to the module-level convergence history."""
    global iter_count
    iter_count += 1
    iterations_list.append(iter_count)
    accepted_scores_list.append(value)
    best_scores_so_far.append(best_score)


def _percent(part, whole):
    """Return part / whole as a percentage, or 0 when whole is 0."""
    return 0 if whole == 0 else part / whole * 100


# Simulated annealing algorithm with visualization
def simulated_annealing(*, initial_temp=0.1, alpha=0.99, final_temp=None,
                        trials_per_temp=100):
    """Minimise ``score`` subject to g1..g4 <= 0 by simulated annealing.

    Keyword Args:
        initial_temp: starting temperature; must be > 0.
        alpha: geometric cooling factor applied after each generation; must be
            in (0, 1).
        final_temp: stop once the temperature falls below this value. Defaults
            to initial_temp / 100000 and must be > 0.
        trials_per_temp: neighbour proposals evaluated per temperature.

    Returns:
        (best_sol, best_score, generation). best_sol is an (L, d, w) tuple and
        generation is the number of temperature steps performed.

    Raises:
        ValueError: if the schedule parameters would make cooling never finish.
    """
    global iter_count, iterations_list, accepted_scores_list, best_scores_so_far

    if initial_temp <= 0:
        raise ValueError(f"initial_temp must be > 0, got {initial_temp!r}")
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be in (0, 1), got {alpha!r}")
    if final_temp is None:
        final_temp = initial_temp / 100000
    if final_temp <= 0:
        raise ValueError(f"final_temp must be > 0, got {final_temp!r}")

    # Rebind the module-level history that the plotting code reads.
    iter_count = 0
    iterations_list = []
    accepted_scores_list = []
    best_scores_so_far = []

    # Rejection-sample a feasible starting design.
    sol = random_solution()
    while not satisfies_constraints(*sol):
        sol = random_solution()

    value = score(*sol)
    temp = initial_temp
    generation = 0
    best_sol, best_score = sol, value
    acceptance_rate = 0  # defined up front in case the cooling loop never runs

    _record_accepted(value, best_score)

    # Always draw the starting state, then pause so it can be inspected.
    show_spring_with_sa_info(*sol, value, best_score, generation, temp, 0, force=True)
    time.sleep(2)

    while temp >= final_temp:
        generation += 1
        accepted_worse = 0
        total_worse = 0

        for _ in range(trials_per_temp):
            new_sol = tweak(*sol)
            while not satisfies_constraints(*new_sol):
                new_sol = tweak(*sol)
            new_value = score(*new_sol)

            # Minimisation: improvement >= 0 means the candidate is no worse.
            improvement = value - new_value
            if improvement >= 0:
                sol, value = new_sol, new_value
                if value < best_score:
                    best_sol, best_score = sol, value
                _record_accepted(value, best_score)
            else:
                # Metropolis rule: accept a worse move with probability exp(improvement / T).
                total_worse += 1
                if random.random() <= math.exp(improvement / temp):
                    accepted_worse += 1
                    sol, value = new_sol, new_value
                    _record_accepted(value, best_score)

            show_spring_with_sa_info(*sol, value, best_score, generation, temp,
                                     _percent(accepted_worse, total_worse))

        acceptance_rate = _percent(accepted_worse, total_worse)
        show_spring_with_sa_info(*sol, value, best_score, generation, temp, acceptance_rate)

        print(
            f"Gen #{generation}: temp = {temp:.6f}, "
            f"best score = {best_score:.8f}, "
            f"cur score = {value:.8f}, "
            f"worse accepted = {acceptance_rate:.2f}%"
        )

        temp *= alpha

    # Always draw the best design before handing control to the blocking GUI loop.
    show_spring_with_sa_info(*best_sol, best_score, best_score, generation, temp,
                             acceptance_rate, force=True)
    print(f"Final result: L={best_sol[0]:.8f}, d={best_sol[1]:.8f}, w={best_sol[2]:.8f}, score={best_score:.8f}")
    plt.show(block=True)

    return best_sol, best_score, generation

# Entry point: optimise only when the file is executed as a script, not on import.
if __name__ == "__main__":
    _best_sol, _best_score, _generations = simulated_annealing()