import random
import numpy as np
import matplotlib.pyplot as plt
import time

from sympy import im

def g1(L, d, w):
    """First design constraint; the design is feasible only when this is <= 0.

    Presumably the minimum-deflection constraint of the classic
    tension/compression spring benchmark (L = number of active coils,
    d = mean coil diameter, w = wire diameter), published as
    ``1 - D**3 * N / (71785 * d**4)``. The constant is 71785 (not 7178,
    which would make the constraint ten times looser than the benchmark).
    """
    return 1 - d**3 * L / (71785 * w**4)

def g2(L, d, w):
    """Second design constraint; the design is feasible only when this is <= 0.

    Presumably the shear-stress constraint of the classic tension/compression
    spring benchmark, published as
    ``(4*D**2 - d*D) / (12566 * (D*d**3 - d**4)) + 1/(5108*d**2) - 1``
    with D = ``d`` (mean coil diameter) and d = ``w`` (wire diameter).
    The 12566 factor multiplies the whole ``(d*w**3 - w**4)`` difference;
    without the parentheses only the first term would be scaled.
    ``L`` does not enter this constraint.
    """
    return (4 * d**2 - w * d) / (12566 * (d * w**3 - w**4)) + 1 / (5108 * w**2) - 1

def g3(L, d, w):
    """Third design constraint; the design is feasible only when this is <= 0.

    Equivalent to requiring ``140.45 * w / (d**2 * L) >= 1``.
    """
    ratio = 140.45 * w / (L * d**2)
    return 1 - ratio

def g4(L, d, w):
    """Fourth design constraint: feasible (<= 0) only when ``d + w <= 1.5``.

    ``L`` is accepted for a uniform signature with the other constraints but
    is not used.
    """
    combined = d + w
    return combined / 1.5 - 1

def satisfies_constraints(L, d, w):
    """Return True when every constraint g1..g4 evaluates to <= 0.

    Constraints are checked in order and evaluation stops at the first
    violated one.
    """
    return all(constraint(L, d, w) <= 0 for constraint in (g1, g2, g3, g4))

def score(L, d, w):
    """Objective to minimise: ``(L + 2) * d * w**2`` (lower is better)."""
    coils_plus_two = L + 2
    return coils_plus_two * d * w**2

def _perturb(value, delta, lo, hi):
    """Return ``value`` plus uniform noise in ``[-delta, delta]``, resampled until it lies in ``[lo, hi]``.

    Raises:
        ValueError: if no point of ``[value - delta, value + delta]`` can land
            in ``[lo, hi]``; the resampling loop would otherwise never end.
    """
    if value + delta < lo or value - delta > hi:
        raise ValueError(
            f"cannot perturb {value} by at most {delta} into [{lo}, {hi}]"
        )
    candidate = value + random.uniform(-1, 1) * delta
    while candidate < lo or candidate > hi:
        candidate = value + random.uniform(-1, 1) * delta
    return candidate


def tweak(L, d, w, *, delta_L=0.1, delta_d=0.01, delta_w=0.01):
    """Return a random neighbouring design ``(new_L, new_d, new_w)``.

    Each variable is shifted by uniform noise of at most its step size and
    resampled until it is inside its box bound: w in [0.05, 2],
    d in [0.25, 1.3], L in [2, 15]. The constraints g1..g4 are NOT checked
    here.

    Random numbers are drawn for w first, then d, then L, so seeded runs
    reproduce the same sequence as before the step sizes were parameterised.

    Args:
        L, d, w: current design.
        delta_L, delta_d, delta_w: maximum step for each variable (defaults
            are the original hard-coded values).

    Raises:
        ValueError: if a variable is so far outside its bounds that no step
            of the given size can bring it back.
    """
    new_w = _perturb(w, delta_w, 0.05, 2)
    new_d = _perturb(d, delta_d, 0.25, 1.3)
    new_L = _perturb(L, delta_L, 2, 15)
    return (new_L, new_d, new_w)

def random_solution():
    """Draw a design uniformly from the search box.

    Returns ``(L, d, w)`` with L in [2, 15], d in [0.25, 1.3] and
    w in [0.05, 2], sampled in that order. Constraints g1..g4 are not checked.
    """
    coils = random.uniform(2, 15)
    coil_size = random.uniform(0.25, 1.3)
    wire = random.uniform(0.05, 2)
    return coils, coil_size, wire
        
# Initialize the figure for animation: one window with the 3-D spring on the
# left and the score-vs-iteration trace on the right.
fig = plt.figure(figsize=(16, 9))
ax1 = fig.add_subplot(1, 2, 1, projection='3d')  # spring rendering
ax2 = fig.add_subplot(1, 2, 2)                   # convergence trace
ax1.set_xlim((-2, 2))
ax1.set_ylim((-2, 2))
ax1.set_zlim((0, 15))
plt.ion()  # Turn on interactive mode

# Shared convergence data. hill_climbing_with_restarts rebinds these names via
# `global`, and show_spring_with_restart_info reads them when plotting.
iterations_list, scores_list, restart_markers, best_scores_list = [], [], [], []

# Function to update the spring visualization with restart information
def show_spring_with_restart_info(L, d, w, current_score, best_score, restart_num, redraw_spring=True):
    """Redraw the live figure for the current state of the search.

    Left panel (``ax1``, 3-D), only when ``redraw_spring`` is true: a helix of
    ``L`` turns with radius ``d`` and line width ``w * 30``, plus a text box
    with the restart number, current/best score and each constraint value
    g1..g4 (✓ when <= 0).

    Right panel (``ax2``), on every call: all recorded scores from the
    module-level ``iterations_list``/``scores_list``, the best-so-far curve
    from ``best_scores_list``, and a dashed line for every entry of
    ``restart_markers`` except the last one (the restart in progress).

    Args:
        L, d, w: design to draw.
        current_score: score shown as "Current Score".
        best_score: score shown as "Best Score".
        restart_num: restart index shown in the text box.
        redraw_spring: when False, the 3-D panel is left untouched.
    """
    if redraw_spring:
        _draw_spring(L, d, w, current_score, best_score, restart_num)
    _draw_convergence()

    # Draw and flush so the interactive window refreshes without blocking.
    fig.tight_layout()
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.pause(0.000001)


def _draw_spring(L, d, w, current_score, best_score, restart_num):
    """Clear ``ax1`` and draw the helix plus the score/constraint text box."""
    ax1.clear()
    # clear() discards the limits, so restore the fixed view box each time.
    ax1.set_xlim([-2, 2])
    ax1.set_ylim([-2, 2])
    ax1.set_zlim([0, 15])

    # L full turns; z rises one unit per turn, so the top of the helix is at z = L.
    theta = np.linspace(0, L * 2 * np.pi, 1000)
    z = theta / (2 * np.pi)
    x = d * np.sin(theta)
    y = d * np.cos(theta)
    ax1.plot(x, y, z, color='b', lw=w*30)

    g1_val = g1(L, d, w)
    g2_val = g2(L, d, w)
    g3_val = g3(L, d, w)
    g4_val = g4(L, d, w)

    ax1.set_title(f"Spring Optimization (L={L:.4f}, d={d:.4f}, w={w:.4f})")

    constraint_text = (
        f"Restart: {restart_num}\n"
        f"Current Score: {current_score:.8f}\n"
        f"Best Score: {best_score:.8f}\n"
        f"g1: {g1_val:.4f} {'✓' if g1_val <= 0 else '✗'}\n"
        f"g2: {g2_val:.4f} {'✓' if g2_val <= 0 else '✗'}\n"
        f"g3: {g3_val:.4f} {'✓' if g3_val <= 0 else '✗'}\n"
        f"g4: {g4_val:.4f} {'✓' if g4_val <= 0 else '✗'}"
    )
    ax1.text2D(0.02, 0.95, constraint_text, transform=ax1.transAxes,
               fontsize=10, verticalalignment='top',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))


def _draw_convergence():
    """Clear ``ax2`` and plot every score, past restart markers and the best-so-far curve."""
    ax2.clear()
    ax2.set_title("Optimization Progress")
    ax2.set_xlabel("Iterations")
    ax2.set_ylabel("Score")

    if not iterations_list:
        return

    # scores_list holds every evaluated candidate from every restart.
    ax2.plot(iterations_list, scores_list, 'b.', alpha=0.5, label='All Evaluations')

    # The last marker belongs to the restart in progress, so it is skipped.
    past_markers = restart_markers[:-1]
    if past_markers:
        # Computed once instead of rescanning scores_list for every marker.
        label_y = min(scores_list)
        for i, marker in enumerate(past_markers):
            ax2.axvline(x=marker, color='r', linestyle='--', alpha=0.5)
            ax2.text(marker, label_y, f"R{i+1}", verticalalignment='bottom',
                     horizontalalignment='center', color='r', fontsize=8)

    ax2.plot(iterations_list, best_scores_list, 'g-', label='Best Overall')
    ax2.legend(loc='upper right')

# Hill climbing with random restarts
def hill_climbing_with_restarts(max_failures, max_restarts=None):
    """Minimise ``score`` subject to g1..g4 <= 0 by hill climbing with random restarts.

    Each restart:
    - samples ``random_solution`` until one satisfies all constraints;
    - repeatedly proposes feasible ``tweak``s of the current design, accepting
      a proposal only when it strictly lowers the score;
    - ends after ``max_failures`` consecutive non-improving proposals.

    Progress is recorded in the module-level lists ``iterations_list``,
    ``scores_list``, ``best_scores_list`` and ``restart_markers``. These are
    reset on entry and drawn via ``show_spring_with_restart_info``. Each
    restart sleeps 2 s at its start and at its end so the plot can be seen.

    Args:
        max_failures: consecutive non-improving proposals that end a restart.
        max_restarts: number of restarts to run before returning. ``None``
            (the default) keeps restarting forever.

    Returns:
        ``(best_sol, best_value, total_iterations)`` after ``max_restarts``
        restarts. The function never returns when ``max_restarts`` is None.
    """
    global iterations_list, scores_list, restart_markers, best_scores_list

    # Reset tracking lists
    iterations_list = []
    scores_list = []
    restart_markers = [0]  # Start of first restart
    best_scores_list = []

    best_sol = None
    best_value = float('inf')
    total_iterations = 0
    restart_num = 0

    while max_restarts is None or restart_num < max_restarts:
        restart_num += 1
        print(f"Starting restart {restart_num}")

        # Rejection-sample a feasible starting design.
        start = random_solution()
        while not satisfies_constraints(*start):
            start = random_solution()

        sol = start
        value = score(*sol)

        # restart_markers already holds 0 for the first restart, so only
        # later restarts record where they begin.
        if restart_num > 1:
            restart_markers.append(total_iterations)

        # Add initial point to tracking
        iterations_list.append(total_iterations)
        scores_list.append(value)
        best_scores_list.append(min(best_value, value))

        if best_sol is None or value < best_value:
            best_sol = sol
            best_value = value

        show_spring_with_restart_info(*sol, value, best_value, restart_num)
        time.sleep(2)  # Pause to see initial state

        failures = 0
        local_iterations = 0

        # Main optimization loop for this restart
        while failures < max_failures:
            failures += 1
            local_iterations += 1
            total_iterations += 1

            # Keep tweaking the current design until the proposal is feasible.
            new_sol = tweak(*sol)
            while not satisfies_constraints(*new_sol):
                new_sol = tweak(*sol)

            new_value = score(*new_sol)

            # Every evaluated proposal is recorded, accepted or not.
            iterations_list.append(total_iterations)
            scores_list.append(new_value)

            improvement = new_value < value
            if improvement:
                failures = 0  # the failure budget counts consecutive misses
                sol = new_sol
                value = new_value
                if value < best_value:
                    best_sol = sol
                    best_value = value

            best_scores_list.append(best_value)

            # Redraw on every improvement, plus occasionally (~1 in 1000) so
            # the convergence plot keeps moving during long plateaus.
            if random.randint(1, 1000) == 1 or improvement:
                show_spring_with_restart_info(*sol, value, best_value, restart_num, improvement)

            print(f"Restart {restart_num}, Iteration {local_iterations}: L={sol[0]:.4f}, "
                  f"d={sol[1]:.4f}, w={sol[2]:.4f}, score={value:.4f}, "
                  f"best overall={best_value:.4f}")

        print(f"Completed restart {restart_num}: best score={value:.8f}, "
              f"iterations={local_iterations}")

        # Show final state for this restart
        show_spring_with_restart_info(*sol, value, best_value, restart_num)
        time.sleep(2.0)  # Pause to see final state of this restart

    return best_sol, best_value, total_iterations

if __name__ == "__main__":
    # Only start the (endless, GUI-driven) optimisation when run as a script,
    # not when this module is imported.
    hill_climbing_with_restarts(max_failures=100_000)
