import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from mpl_toolkits.mplot3d import Axes3D


# Define the function f(x,y) = x*cos(x)*(sin(y))^2
def f(x, y):
    """Evaluate f(x, y) = x * cos(x) * sin(y)**2.

    Operates element-wise, so ``x`` and ``y`` may be scalars or
    broadcast-compatible NumPy arrays (e.g. the output of ``np.meshgrid``).
    """
    x_factor = x * np.cos(x)
    return x_factor * np.sin(y) ** 2


# Define the gradient of f(x,y)
def gradient(x, y):
    """Return the analytic gradient of f at (x, y) as ``np.array([df/dx, df/dy])``.

    With f(x, y) = x * cos(x) * sin(y)**2:
        df/dx = (cos(x) - x * sin(x)) * sin(y)**2
        df/dy = 2 * x * cos(x) * sin(y) * cos(y)

    Works element-wise; for array inputs the result has a leading axis of
    length 2 stacked in front of the broadcast input shape.
    """
    cos_x = np.cos(x)
    sin_y = np.sin(y)
    cos_y = np.cos(y)

    partial_x = (cos_x - x * np.sin(x)) * sin_y ** 2
    partial_y = 2 * x * cos_x * sin_y * cos_y
    return np.array([partial_x, partial_y])


# Two panels side by side: a 3D surface on the left, a filled contour map
# (where the descent path is drawn) on the right.
fig = plt.figure(figsize=(12, 8))
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
ax2 = fig.add_subplot(1, 2, 2)

# Sample f on a 100 x 100 grid covering x in [-2, 0] and y in [0, 3].
x = np.linspace(-2, 0, 100)
y = np.linspace(0, 3, 100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# Left panel: semi-transparent surface of f.
surface = ax1.plot_surface(X, Y, Z, cmap="viridis", alpha=0.8)
ax1.set(
    xlabel="X",
    ylabel="Y",
    zlabel="f(X,Y)",
    title="Function: f(x,y) = x*cos(x)*(sin(y))^2",
)

# Right panel: 50 filled contour levels with a colorbar, limited to the
# same domain as the sampled grid.
contour = ax2.contourf(X, Y, Z, 50, cmap="viridis")
plt.colorbar(contour, ax=ax2)
ax2.set(
    xlabel="X",
    ylabel="Y",
    title="Contour Plot with Gradient Descent",
    xlim=[-2, 0],
    ylim=[0, 3],
)

# Initialize the starting point and its height on the surface.
current_x, current_y = -1.5, 2.5
current_z = f(current_x, current_y)

# Plot the initial point as a red marker on both panels.
point_3d = ax1.scatter(current_x, current_y, current_z, color="red", s=100)
point_2d = ax2.scatter(current_x, current_y, color="red", s=100)

# Gradient at the start; both arrows show the NEGATIVE gradient (descent direction).
grad = gradient(current_x, current_y)

# 2D arrow from (x, y) to (x - gx, y - gy) in data coordinates. angles="xy" is
# required for that together with scale_units="xy", scale=1: with the default
# angles="uv" the arrow is oriented in screen space and does not follow the
# data-space descent direction unless the axes have equal aspect.
grad_arrow = ax2.quiver(
    current_x,
    current_y,
    -grad[0],
    -grad[1],
    color="red",
    angles="xy",
    scale_units="xy",
    scale=1,
)

# 3D arrow: normalized to a fixed length of 0.2 and kept horizontal
# (z-component 0), so it does not follow the surface slope.
grad_arrow_3d = ax1.quiver(
    current_x,
    current_y,
    current_z,
    -grad[0],
    -grad[1],
    0,  # Using 0 for z-component for simplicity
    color="red",
    length=0.2,
    normalize=True,
)

# Text readout stacked down the upper-left of the contour panel; positions
# are in ax2 data coordinates (x in [-2, 0], y in [0, 3]).
iteration = 0
text_iter = ax2.text(
    -1.9, 2.8,
    f"Iteration: {iteration}",
    fontsize=10,
)
text_point = ax2.text(
    -1.9, 2.6,
    f"Point: ({current_x:.4f}, {current_y:.4f})",
    fontsize=10,
)
text_gradient = ax2.text(
    -1.9, 2.4,
    f"Gradient: ({grad[0]:.4f}, {grad[1]:.4f})",
    fontsize=10,
)
text_function = ax2.text(
    -1.9, 2.2,
    f"f(x,y): {current_z:.4f}",
    fontsize=10,
)

# "Step" button in its own small axes near the bottom-right of the figure
# (rect is [left, bottom, width, height] in figure fractions).
button_ax = plt.axes([0.81, 0.05, 0.1, 0.075])
button = Button(button_ax, "Step")


# Function to update the plot for each step of gradient descent
def step_gradient_descent(event, learning_rate=0.05):
    """Advance gradient descent by one step and redraw both panels.

    Moves the current point against the gradient of ``f``
    (``p <- p - learning_rate * gradient(p)``) and clamps it to the plotted
    domain x in [-2, 0], y in [0, 3]. It then updates the markers, the
    descent-direction arrows and the text readout, and requests a redraw.

    Args:
        event: The matplotlib event that triggered the step; not used.
        learning_rate: Step-size multiplier applied to the gradient.
            Defaults to 0.05, the value previously hard-coded here.
    """
    # Only names that are rebound here need ``global``; point_3d / point_2d
    # are mutated in place, while the two arrows are replaced outright.
    global current_x, current_y, current_z, iteration, grad_arrow, grad_arrow_3d

    # Descent step: move in the NEGATIVE gradient direction.
    grad = gradient(current_x, current_y)
    current_x -= learning_rate * grad[0]
    current_y -= learning_rate * grad[1]

    # Keep the point inside the plotted domain.
    current_x = np.clip(current_x, -2, 0)
    current_y = np.clip(current_y, 0, 3)

    current_z = f(current_x, current_y)
    iteration += 1

    # Move the markers; the 3D scatter position is set through its private
    # ``_offsets3d`` attribute.
    point_3d._offsets3d = ([current_x], [current_y], [current_z])
    point_2d.set_offsets([current_x, current_y])

    # Gradient at the NEW point: drives the arrows and the readout.
    grad = gradient(current_x, current_y)

    # Replace the 2D arrow. angles="xy" (with scale_units="xy", scale=1) makes
    # it run from (x, y) to (x - gx, y - gy) in data coordinates; the default
    # angles="uv" orients it in screen space, so on non-equal-aspect axes it
    # would not point along the actual descent direction.
    grad_arrow.remove()
    grad_arrow = ax2.quiver(
        current_x,
        current_y,
        -grad[0],
        -grad[1],
        color="red",
        angles="xy",
        scale_units="xy",
        scale=1,
    )

    # Replace the 3D arrow. Removing the tracked artist directly avoids
    # mutating ax1.collections while iterating over it.
    grad_arrow_3d.remove()
    grad_arrow_3d = ax1.quiver(
        current_x,
        current_y,
        current_z,
        -grad[0],
        -grad[1],
        0,  # Horizontal arrow; the z-component is not computed.
        color="red",
        length=0.2,
        normalize=True,
    )

    # Refresh the text readout.
    text_iter.set_text(f"Iteration: {iteration}")
    text_point.set_text(f"Point: ({current_x:.4f}, {current_y:.4f})")
    text_gradient.set_text(f"Gradient: ({grad[0]:.4f}, {grad[1]:.4f})")
    text_function.set_text(f"f(x,y): {current_z:.4f}")

    # Schedule a repaint instead of drawing synchronously.
    fig.canvas.draw_idle()


# Keyboard shortcut: pressing Enter is equivalent to clicking "Step".
def on_key(event):
    """Run one gradient-descent step when the pressed key is Enter."""
    if event.key != "enter":
        return
    step_gradient_descent(event)


# Wire up both triggers: the "Step" button and key presses on the canvas.
button.on_clicked(step_gradient_descent)
fig.canvas.mpl_connect("key_press_event", on_key)

# Add title to the figure
plt.suptitle("Gradient Descent Visualization", fontsize=16)

# Reserve the top 5% of the figure for the suptitle and a bottom strip for the
# "Step" button. The button axes was created with plt.axes(rect), which
# tight_layout does not manage, so with a bottom edge of 0 the subplots (and
# the colorbar) would be laid out over the button at figure y 0.05-0.125.
plt.tight_layout(rect=[0, 0.13, 1, 0.95])

plt.show()
