import pygame
import numpy as np
import math
import random
import sys
import time
import traceback
import copy  # For deep copying controllers

# Seed both RNGs so runs are reproducible: ``random`` drives track generation
# and mutation choices, NumPy drives PolynomialController initialisation.
random.seed(3)
np.random.seed(3)

# Global variable to track fast mode state (toggled with the F key while rendering)
FAST_MODE = False
TRACK_DIFFICULTY = "hard"  # ["easy", "medium", "hard", "random"]

# Track representation
class Track:
    """A closed circuit bounded by an outer and an inner polyline.

    Attributes:
        outer_points, inner_points: Boundary vertices; each is treated as a
            closed loop (the last vertex connects back to the first).
        outer_segments, inner_segments, all_segments: Wall segments as
            ``(point, next_point)`` pairs.
        start_finish: ``(outer_points[0], inner_points[0])`` - the lap line.
        difficulty: Difficulty label used to colour the boundaries.
        special_features: Extra features; ``("split_path", start, end)``
            entries are drawn as white rungs between the two boundaries.
    """

    # Boundary colours keyed by difficulty label.
    _BOUNDARY_COLORS = {
        "easy": (0, 255, 0),      # green
        "medium": (255, 165, 0),  # orange
        "hard": (255, 0, 0),      # red
    }
    _FALLBACK_COLOR = (255, 215, 0)  # yellow/gold for any other label

    def __init__(self, outer_points, inner_points):
        self.outer_points = outer_points
        self.inner_points = inner_points
        self.outer_segments = self._create_segments(outer_points)
        self.inner_segments = self._create_segments(inner_points)
        self.all_segments = self.outer_segments + self.inner_segments
        self.start_finish = (outer_points[0], inner_points[0])
        self.difficulty = "medium"  # Default difficulty
        self.special_features = []  # List of special track features

    def _create_segments(self, points):
        """Pair every point with its successor, wrapping the last back to the first."""
        count = len(points)
        return [(point, points[(index + 1) % count]) for index, point in enumerate(points)]

    def draw(self, screen):
        """Draw both boundaries, the start/finish line and any split-path markers."""
        color = self._BOUNDARY_COLORS.get(
            getattr(self, "difficulty", None), self._FALLBACK_COLOR
        )
        for boundary in (self.outer_points, self.inner_points):
            pygame.draw.lines(screen, color, True, boundary, 2)
        pygame.draw.line(screen, (255, 255, 255),
                         self.start_finish[0], self.start_finish[1], 2)

        for feature in self.special_features:
            if feature[0] != "split_path":
                continue
            lo, hi = feature[1], feature[2]
            count = len(self.outer_points)
            for index in range(count):
                # Angular position of this vertex around the loop, in [0, 2*pi).
                theta = 2 * math.pi * index / count
                if lo <= theta <= hi:
                    pygame.draw.line(screen, (255, 255, 255),
                                     self.outer_points[index],
                                     self.inner_points[index], 1)

# Car physics and sensing
class Car:
    """Kinematic car with ray-cast distance sensors, wall collision and lap counting."""

    def __init__(self, x, y, angle=0, color=(255, 0, 0)):
        """Place a stationary car at ``(x, y)`` heading ``angle`` radians.

        Args:
            x, y: Initial position.
            angle: Initial heading in radians. It is also remembered as
                ``start_angle``, which defines the forward direction used by
                :meth:`check_lap`.
            color: RGB fill colour used by :meth:`draw` (default red).
        """
        self.x = x
        self.y = y
        self.prev_x = x  # position before the last update(), for lap detection
        self.prev_y = y
        self.angle = angle
        self.start_angle = angle  # launch heading = "forward" for lap counting
        self.velocity = 0
        self.steering_angle = 0
        # Turning ability at the current speed. update() recomputes it every
        # step; initialising it here lets it be read before the first update.
        self.steering_factor = 1.0 / (1.0 + abs(self.velocity))
        self.sensor_count = 7
        self.sensor_range = 200
        self.laps = 0
        self.alive = True
        self.distance_traveled = 0
        self.time_alive = 0
        self.debug_info = []
        self.total_speed = 0
        self.color = color  # Car color, default is red

    def get_sensor_readings(self, track):
        """Cast ``sensor_count`` rays over a 180-degree arc centred on the heading.

        Args:
            track: Object exposing ``all_segments`` as ``(p1, p2)`` wall pairs.

        Returns:
            ``(readings, endpoints)``: for each ray, the distance to the nearest
            wall hit (``sensor_range`` when nothing is hit) and the hit point
            (or the ray's far end).
        """
        readings = []
        sensor_endpoints = []
        start = (self.x, self.y)

        for i in range(self.sensor_count):
            # Rays are spread evenly from -90 to +90 degrees around the heading.
            sensor_angle = self.angle - math.pi/2 + i * math.pi/(self.sensor_count-1)
            end = (self.x + math.cos(sensor_angle) * self.sensor_range,
                   self.y + math.sin(sensor_angle) * self.sensor_range)

            closest_distance = self.sensor_range
            closest_point = end

            for seg_start, seg_end in track.all_segments:
                intersection = self._line_intersection(start, end, seg_start, seg_end)
                if intersection:
                    distance = math.sqrt((intersection[0] - self.x)**2 +
                                         (intersection[1] - self.y)**2)
                    if distance < closest_distance:
                        closest_distance = distance
                        closest_point = intersection

            readings.append(closest_distance)
            sensor_endpoints.append(closest_point)

        return readings, sensor_endpoints

    def _line_intersection(self, line1_start, line1_end, line2_start, line2_end):
        """Return the intersection point of two segments, or None.

        Parallel or collinear segments (zero denominator) return None.
        Endpoints are inclusive.
        """
        x1, y1 = line1_start
        x2, y2 = line1_end
        x3, y3 = line2_start
        x4, y4 = line2_end

        denominator = (y4-y3)*(x2-x1) - (x4-x3)*(y2-y1)
        if denominator == 0:
            return None

        # ua / ub: parametric positions of the crossing along each segment.
        ua = ((x4-x3)*(y1-y3) - (y4-y3)*(x1-x3)) / denominator
        ub = ((x2-x1)*(y1-y3) - (y2-y1)*(x1-x3)) / denominator

        if 0 <= ua <= 1 and 0 <= ub <= 1:
            return (x1 + ua * (x2-x1), y1 + ua * (y2-y1))

        return None

    def update(self, acceleration, steering):
        """Advance the car by one simulation step.

        Args:
            acceleration: Throttle command, scaled by 0.1 and added to velocity.
            steering: Steering command. Its effect shrinks as speed grows.
        """
        # Remember where we were, for lap detection.
        self.prev_x = self.x
        self.prev_y = self.y

        self.velocity += acceleration * 0.1
        # Speed is clamped to [0.5, 5]. The 0.5 floor keeps the car moving
        # (it can never stop or reverse), so no separate kick-start is needed.
        self.velocity = max(0.5, min(5, self.velocity))

        # Steering is less effective at higher speed.
        self.steering_factor = 1.0 / (1.0 + abs(self.velocity))
        self.steering_angle += steering * self.steering_factor * 0.3

        # Gradually return steering to center, then clamp to +/-60 degrees.
        self.steering_angle *= 0.9
        self.steering_angle = max(-math.pi/3, min(math.pi/3, self.steering_angle))

        # Turn rate scales with both steering angle and speed.
        self.angle += self.steering_angle * self.velocity * 0.03

        self.x += math.cos(self.angle) * self.velocity
        self.y += math.sin(self.angle) * self.velocity

        # Bookkeeping for fitness (distance, survival time, average speed).
        self.distance_traveled += abs(self.velocity)
        self.time_alive += 1
        self.total_speed += abs(self.velocity)

        self.debug_info = [
            f"Velocity: {self.velocity:.2f}",
            f"Steering: {self.steering_angle:.2f}",
            f"Acc Input: {acceleration:.2f}",
            f"Steer Input: {steering:.2f}"
        ]

    def check_collision(self, track):
        """Mark the car dead and return True if any wall is within its radius (8 px)."""
        car_radius = 8

        for p1, p2 in track.all_segments:
            line_vec = (p2[0] - p1[0], p2[1] - p1[1])
            line_len = math.sqrt(line_vec[0]**2 + line_vec[1]**2)

            if line_len == 0:
                continue

            # Project the car centre onto the segment, clamped to its ends.
            line_unitvec = (line_vec[0] / line_len, line_vec[1] / line_len)
            car_vec = (self.x - p1[0], self.y - p1[1])
            projection = (car_vec[0] * line_unitvec[0] +
                          car_vec[1] * line_unitvec[1])
            clamped = max(0, min(line_len, projection))

            closest_point = (
                p1[0] + line_unitvec[0] * clamped,
                p1[1] + line_unitvec[1] * clamped
            )

            distance = math.sqrt((self.x - closest_point[0])**2 +
                                 (self.y - closest_point[1])**2)

            if distance < car_radius:
                self.alive = False
                return True

        return False

    def check_lap(self, track):
        """Count a lap if the last move crossed the start/finish line forwards.

        "Forwards" means crossing in the same sense as the car's launch heading
        (``start_angle``) relative to the line. A fixed sign convention only
        matches one launch direction: with the previous ``cross > 0`` rule, a
        car launched along the line's ``(-uy, ux)`` normal could never
        register a forward lap.

        Returns:
            True if a lap was counted on this step.
        """
        line_a, line_b = track.start_finish
        if not self._line_intersection((self.prev_x, self.prev_y),
                                       (self.x, self.y), line_a, line_b):
            return False

        line_dx = line_b[0] - line_a[0]
        line_dy = line_b[1] - line_a[1]
        # The sign of a 2-D cross product tells which way a vector crosses the line.
        move_side = (self.x - self.prev_x) * line_dy - (self.y - self.prev_y) * line_dx
        launch_side = (math.cos(self.start_angle) * line_dy -
                       math.sin(self.start_angle) * line_dx)
        if launch_side == 0:
            # Launched parallel to the line: fall back to a fixed orientation.
            launch_side = 1.0

        if move_side * launch_side > 0:
            self.laps += 1
            return True

        return False

    def draw(self, screen, sensor_endpoints=None, show_sensors=True):
        """Draw the car body, a heading indicator and (optionally) its sensor rays."""
        car_color = self.color
        car_length = 20
        car_width = 10
        cos_a = math.cos(self.angle)
        sin_a = math.sin(self.angle)

        # Corners of the rotated rectangle: front-left, front-right, rear-right, rear-left.
        corners = [
            (self.x + cos_a * car_length/2 - sin_a * car_width/2,
             self.y + sin_a * car_length/2 + cos_a * car_width/2),
            (self.x + cos_a * car_length/2 + sin_a * car_width/2,
             self.y + sin_a * car_length/2 - cos_a * car_width/2),
            (self.x - cos_a * car_length/2 + sin_a * car_width/2,
             self.y - sin_a * car_length/2 - cos_a * car_width/2),
            (self.x - cos_a * car_length/2 - sin_a * car_width/2,
             self.y - sin_a * car_length/2 + cos_a * car_width/2)
        ]

        pygame.draw.polygon(screen, car_color, corners)

        # Direction indicator from the centre to the front bumper.
        front_x = self.x + cos_a * car_length/2
        front_y = self.y + sin_a * car_length/2
        pygame.draw.line(screen, (255, 255, 0), (self.x, self.y), (front_x, front_y), 2)

        if show_sensors and sensor_endpoints:
            for endpoint in sensor_endpoints:
                pygame.draw.line(screen, (255, 255, 255),
                                 (self.x, self.y), endpoint, 1)

# Polynomial controller
class PolynomialController:
    """Linear (degree-1 polynomial) controller mapping car inputs to commands.

    Inputs are the normalised sensor readings followed by the normalised
    steering angle and velocity. Each output is ``bias + sum(coeff_i * input_i)``,
    clamped to ``[-5, 5]``.
    """

    def __init__(self, input_size):
        """
        Args:
            input_size: Total number of inputs (sensor count + 2).
        """
        self.input_size = input_size
        self.sensor_count = input_size - 2  # Last 2 inputs are steering and velocity

        # Random acceleration coefficients: one bias plus one weight per input.
        self.acceleration_coeffs = np.random.uniform(-1, 1, input_size + 1)

        # Hand-designed steering preset: bias, 7 sensor weights, steering, velocity.
        # NOTE(review): the preset always has 10 entries, so it only lines up
        # with input_size == 9; a larger input_size would index past its end
        # in predict().
        self.steering_coeffs = np.array([0, -1, -0.7, -0.3, 0, 0.3, 0.7, 1.0, 0, 0])

    def predict(self, sensor_readings, current_steering, current_velocity):
        """Compute ``(acceleration, steering)`` commands, each clamped to [-5, 5].

        Args:
            sensor_readings: Wall distances, divided by 200 (the car's sensor range).
            current_steering: Steering angle, divided by pi/3 (the steering limit).
            current_velocity: Speed, divided by 5 (the top speed).
        """
        normalized_inputs = [reading / 200 for reading in sensor_readings]
        normalized_inputs.append(current_steering / (math.pi/3))
        normalized_inputs.append(current_velocity / 5.0)

        acceleration = self.acceleration_coeffs[0]  # bias term
        steering = self.steering_coeffs[0]  # bias term

        for i, input_val in enumerate(normalized_inputs):
            acceleration += self.acceleration_coeffs[i+1] * input_val
            steering += self.steering_coeffs[i+1] * input_val

        acceleration = max(-5, min(5, acceleration))
        steering = max(-5, min(5, steering))

        return acceleration, steering

    def mutate(self, num_change, mutation_amount):
        """Return a new controller with ``num_change`` coefficients perturbed.

        Acceleration and steering coefficients form one flat pool (acceleration
        first, then steering). ``num_change`` distinct entries are chosen from
        the whole pool and each is multiplied by ``1 + U(-mutation_amount,
        mutation_amount)``. Multiplicative noise leaves zero coefficients at zero.
        ``num_change`` is capped at the pool size. This instance is not modified.
        """
        print(f"Mutating with {num_change} changes and mutation amount {mutation_amount}")
        new_controller = PolynomialController(self.input_size)
        new_controller.acceleration_coeffs = self.acceleration_coeffs.copy()
        new_controller.steering_coeffs = self.steering_coeffs.copy()

        n_acc = len(new_controller.acceleration_coeffs)
        total = n_acc + len(new_controller.steering_coeffs)
        # Sample from the real pool size rather than a hard-coded 20 so every
        # coefficient is reachable whatever input_size is.
        to_change = set(random.sample(range(total), min(num_change, total)))

        # Ascending order keeps RNG draws in the same order as a linear scan.
        for flat_index in sorted(to_change):
            factor = 1 + random.uniform(-mutation_amount, mutation_amount)
            if flat_index < n_acc:
                new_controller.acceleration_coeffs[flat_index] *= factor
            else:
                new_controller.steering_coeffs[flat_index - n_acc] *= factor

        print(f"Old: {self.acceleration_coeffs}\nNew: {new_controller.acceleration_coeffs}")
        print(f"Old: {self.steering_coeffs}\nNew: {new_controller.steering_coeffs}")
        print("====")

        return new_controller

# Hill climbing optimization
class HillClimber:
    """Hill-climbing search over PolynomialController coefficients.

    Holds the best controller found so far, its fitness, and a schedule of
    ``(num_change, mutation_amount)`` pairs for generating candidate mutations.
    """

    def __init__(self, track, start_pos, start_angle, input_size=7):
        """
        Args:
            track: Track the cars are evaluated on.
            start_pos: ``(x, y)`` start position for every evaluated car.
            start_angle: Initial heading in radians.
            input_size: Number of distance sensors. The controller receives
                two extra inputs (steering angle and velocity).
        """
        self.track = track
        self.start_pos = start_pos
        self.start_angle = start_angle
        self.input_size = input_size

        # Start from the preset controller and score it headlessly.
        self.best_controller = self.create_default_controller()
        self.best_fitness = self.evaluate(self.best_controller)
        print(f"Initial fitness: {self.best_fitness}")

        self.attempts_without_improvement = 0
        self.mutation_index = 0
        # Every (how many coefficients, relative amount) combination, ordered
        # by num_change first and then by amount.
        self.mutation_parameters = [
            (num_change, amount_change)
            for num_change in range(1, 21)
            for amount_change in [0.01, 0.03, 0.05, 0.07, 0.09, 0.12, 0.16, 0.2, 0.3, 0.4,
                                  0.5, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
        ]

    def create_default_controller(self):
        """Create the starting controller from hard-coded preset coefficients.

        The presets have 10 entries each (bias + 7 sensors + steering +
        velocity), matching the default ``input_size`` of 7 sensors.
        """
        # Controller input size is sensor count + 2 (steering and velocity).
        controller = PolynomialController(self.input_size + 2)

        controller.acceleration_coeffs = np.array([
            2.97639394e-03, 8.51476912e-03, 1.67934095e-03, 1.11191525e-02,
            -2.38833541e-03, 1.78328601e-03, -1.55906629e-03, 4.80024719e-01,
            -2.70270019e-02, -2.17413392e+00,
        ])
        controller.steering_coeffs = np.array([
            0., -4.87792359, -0.43409822, -0.22670907, 0.,
            2.38510963, 4.60487316, 0.86751384, 0., 0.,
        ])

        return controller

    def _difficulty_color(self):
        """Return the HUD colour for the track's difficulty (white if unknown)."""
        colors = {
            "easy": (0, 255, 0),      # green
            "medium": (255, 165, 0),  # orange
            "hard": (255, 0, 0),      # red
        }
        return colors.get(getattr(self.track, "difficulty", None), (255, 255, 255))

    @staticmethod
    def _coeff_summary(label, coeffs):
        """Format coefficients compactly: the bias as ``B=``, the rest as ``S1=``, ``S2=``, ..."""
        parts = [f"{label}: "]
        for j, coeff in enumerate(coeffs):
            parts.append(f"B={coeff:.2f} " if j == 0 else f"S{j}={coeff:.2f} ")
        return "".join(parts)

    def evaluate(self, controller, max_steps=1000, render=False, debug=False, best_controller=None):
        """Simulate ``controller`` on the track and return its fitness.

        Args:
            controller: Controller driving the red test car.
            max_steps: Simulation step limit.
            render: Draw every step to the current pygame display surface.
            debug: Print crash/lap events (render mode only) and show debug
                text and controller coefficients on screen.
            best_controller: Optional controller for a green comparison car
                simulated alongside the test car.

        Returns:
            The test car's fitness when ``render`` is false. Otherwise a
            ``(fitness, reached_limit)`` tuple, where ``reached_limit`` is True
            if the step limit was hit while at least one car was still alive.
            Closing the window while rendering quits pygame and returns ``(0, False)``.
        """
        global FAST_MODE

        test_car = Car(self.start_pos[0], self.start_pos[1], self.start_angle, color=(255, 0, 0))

        best_car = None
        if best_controller:
            best_car = Car(self.start_pos[0], self.start_pos[1], self.start_angle, color=(0, 255, 0))

        # Last commands issued to the test car (shown in the HUD).
        current_acceleration = 0
        current_steering = 0

        if render:
            screen = pygame.display.get_surface()
            screen_width, screen_height = screen.get_size()
            clock = pygame.time.Clock()
            font = pygame.font.Font(None, 24)
            small_font = pygame.font.Font(None, 16)

        steps = 0

        # Continue until both cars have crashed or max steps is reached.
        while steps < max_steps and (test_car.alive or (best_car is not None and best_car.alive)):
            # Sensors are read even for a crashed test car so its rays are still drawn.
            test_readings, test_sensor_endpoints = test_car.get_sensor_readings(self.track)

            if test_car.alive:
                current_acceleration, current_steering = controller.predict(
                    test_readings,
                    test_car.steering_angle,
                    test_car.velocity
                )
                test_car.update(current_acceleration, current_steering)

                if test_car.check_collision(self.track) and debug and render:
                    print(f"Test car crashed at step {steps}")
                if test_car.check_lap(self.track) and debug and render:
                    print(f"Test car completed lap at step {steps}")

            best_sensor_endpoints = None
            if best_car is not None and best_car.alive:
                best_readings, best_sensor_endpoints = best_car.get_sensor_readings(self.track)
                best_acceleration, best_steering = best_controller.predict(
                    best_readings,
                    best_car.steering_angle,
                    best_car.velocity
                )
                best_car.update(best_acceleration, best_steering)

                if best_car.check_collision(self.track) and debug and render:
                    print(f"Best car crashed at step {steps}")
                if best_car.check_lap(self.track) and debug and render:
                    print(f"Best car completed lap at step {steps}")

            if render:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                        # Keep the render-mode return contract: callers unpack a tuple.
                        return 0, False
                    if event.type == pygame.KEYDOWN and event.key == pygame.K_f:
                        FAST_MODE = not FAST_MODE
                        print(f"Fast mode {'ON - FPS limit removed' if FAST_MODE else 'OFF - FPS limited to 120'}")

                screen.fill((0, 0, 0))
                self.track.draw(screen)

                # Best car first so the test car is drawn on top of it.
                if best_car is not None:
                    best_car.draw(screen, best_sensor_endpoints, show_sensors=False)
                test_car.draw(screen, test_sensor_endpoints, show_sensors=True)

                texts = [
                    f"Test Car (Red) - Speed: {test_car.velocity:.2f}",
                    f"Steering: {current_steering:.2f}",
                    f"Acceleration: {current_acceleration:.2f}",
                    f"Turn ability: {test_car.steering_factor:.2f}",
                    f"Laps: {test_car.laps}",
                    f"Distance: {test_car.distance_traveled:.0f}",
                    f"Step: {steps}/{max_steps}",
                    f"Fitness: {self._calculate_fitness(test_car):.2f}",
                    f"Difficulty: {self.track.difficulty.upper()}"
                ]
                if best_car is not None:
                    texts.extend([
                        "",  # Empty line for spacing
                        f"Best Car (Green) - Speed: {best_car.velocity:.2f}",
                        f"Laps: {best_car.laps}",
                        f"Distance: {best_car.distance_traveled:.0f}",
                        f"Fitness: {self._calculate_fitness(best_car):.2f}",
                    ])
                if debug:
                    texts.extend(test_car.debug_info)

                difficulty_color = self._difficulty_color()
                info_y = 10
                for text in texts:
                    if text.startswith("Difficulty: "):
                        color = difficulty_color
                    elif text.startswith("Test Car"):
                        color = (255, 0, 0)
                    elif text.startswith("Best Car"):
                        color = (0, 255, 0)
                    else:
                        color = (255, 255, 255)
                    screen.blit(font.render(text, True, color), (10, info_y))
                    info_y += 25

                # Controller coefficients along the bottom edge in debug mode.
                if (debug and hasattr(controller, 'acceleration_coeffs')
                        and hasattr(controller, 'steering_coeffs')):
                    acc_surface = small_font.render(
                        self._coeff_summary("Acc", controller.acceleration_coeffs), True, (255, 200, 200))
                    steer_surface = small_font.render(
                        self._coeff_summary("Steer", controller.steering_coeffs), True, (200, 200, 255))
                    screen.blit(acc_surface, (10, screen_height - 40))
                    screen.blit(steer_surface, (10, screen_height - 20))

                speed_indicator = font.render("FAST MODE" if FAST_MODE else "", True, (255, 255, 0))
                screen.blit(speed_indicator, (screen_width - speed_indicator.get_width() - 10, 10))

                pygame.display.flip()

                # Uncapped in fast mode, otherwise at most 120 frames per second.
                if FAST_MODE:
                    clock.tick()
                else:
                    clock.tick(120)

            steps += 1

        fitness = self._calculate_fitness(test_car)

        # True only if the step limit ended the run while some car was still alive.
        reached_limit = steps >= max_steps and (
            test_car.alive or (best_car is not None and best_car.alive))

        if render and debug:
            if not test_car.alive:
                print(f"Test car crashed at step {steps}")
            if best_car is not None and not best_car.alive:
                print(f"Best car crashed at step {steps}")
            if reached_limit:
                print(f"Cars reached step limit ({max_steps})")

        return (fitness, reached_limit) if render else fitness

    def _calculate_fitness(self, car):
        """Score a car: distance + 10 * average speed + 0.1 * steps alive.

        Side effect: appends an "Avg Speed" line to ``car.debug_info`` when it
        holds at most 4 entries.
        """
        fitness = car.distance_traveled

        if car.time_alive > 0:
            avg_speed = car.total_speed / car.time_alive
            # Strong bonus for average speed.
            fitness += avg_speed * 10

            if len(car.debug_info) <= 4:
                car.debug_info.append(f"Avg Speed: {avg_speed:.2f}")

        # Bonus for staying alive.
        fitness += car.time_alive * 0.1

        return fitness

# Calculate dynamic starting position based on track
def calculate_starting_position(track, offset=30):
    """Return a start pose ``offset`` pixels from the middle of the start/finish line.

    The car is placed along the line's normal ``(-uy, ux)`` (the line direction
    rotated by +90 degrees, which appears clockwise on a y-down screen), and
    faces along that same normal, i.e. perpendicular to the start/finish line.
    Whether that normal points forward around the circuit depends on the
    track's winding. NOTE(review): confirm for each track generator.

    Args:
        track: Object with a ``start_finish`` pair of ``(x, y)`` points.
        offset: Distance in pixels from the line's midpoint (default 30).

    Returns:
        ``((x, y), angle)`` with ``angle`` in radians.

    Raises:
        ValueError: If the two start/finish points coincide.
    """
    (x0, y0), (x1, y1) = track.start_finish

    mid_x = (x0 + x1) / 2
    mid_y = (y0 + y1) / 2

    dx = x1 - x0
    dy = y1 - y0
    length = math.sqrt(dx**2 + dy**2)
    if length == 0:
        raise ValueError("start/finish line has zero length; cannot orient the car")

    unit_x, unit_y = dx / length, dy / length
    perp_x, perp_y = -unit_y, unit_x

    start_pos = (
        mid_x + perp_x * offset,
        mid_y + perp_y * offset
    )
    start_angle = math.atan2(perp_y, perp_x)

    return start_pos, start_angle

# Create a more complex track with challenging features based on difficulty
def create_track(difficulty="random"):
    """Generate a closed procedural race track.

    The track is sampled at ``n_points`` angles around a circle centred in an
    800x600 area. At each angle the outer and inner radii are a base radius
    plus a sum of sinusoidal "features" (sharp turns, chicanes, S-curves,
    hairpins, narrow sections, ...). Their amplitudes, frequencies and phases
    are drawn from ``random`` according to the difficulty.

    Feature intensity is reduced near angle 0 (the right-hand side of the
    area, where point 0 and therefore the start/finish line lie). The inner
    wall is clamped to stay at least ``min_track_width`` inside the outer wall
    before smoothing. Both walls are then smoothed once with ``smooth_track``.

    Args:
        difficulty: ``"easy"``, ``"medium"``, ``"hard"``, or ``"random"``
            (one of the three chosen with ``random.choice``). Any other value
            falls back to ``"hard"`` with a printed warning.

    Returns:
        A ``Track`` with its ``difficulty`` and ``special_features``
        attributes set.
    """
    width, height = 800, 600
    center_x, center_y = width // 2, height // 2

    # Randomly choose difficulty if not specified
    if difficulty == "random":
        difficulty = random.choice(["easy", "medium", "hard"])
    elif difficulty not in ("easy", "medium", "hard"):
        # The parameter block below treats anything unrecognised as "hard".
        # The hard-only extras further down compare against the literal
        # string, though, so normalise here. An unknown value then yields a
        # consistent hard track instead of a partial one.
        print(f"Unknown track difficulty {difficulty!r}; using 'hard'")
        difficulty = "hard"

    # Control points for track
    outer_points = []
    inner_points = []

    # Number of points (higher for more complex tracks)
    n_points = 60
    if difficulty == "hard":
        n_points = 80  # More points for higher resolution on complex tracks

    # Set track parameters based on difficulty
    if difficulty == "easy":
        # Easy track - wider, gentler curves, more regular
        base_radius_outer = random.uniform(250, 280)
        base_radius_inner = random.uniform(170, 200)
        min_track_width = 80  # Wider track is easier to navigate

        # Gentle features
        sharp_turn_amp = random.uniform(15, 30)
        sharp_turn_freq = random.uniform(1.5, 2.0)
        sharp_turn_phase = random.uniform(0, 2 * math.pi)

        chicane_amp = random.uniform(10, 20)
        chicane_freq = random.uniform(3, 4)
        chicane_phase = random.uniform(0, 2 * math.pi)

        hairpin_amp = random.uniform(10, 20)
        hairpin_freq1 = random.uniform(1.2, 1.5)
        hairpin_freq2 = random.uniform(0.8, 1.0)
        hairpin_phase = random.uniform(0, 2 * math.pi)

        narrow_amp = random.uniform(10, 15)
        narrow_freq1 = random.uniform(0.8, 1.0)
        narrow_freq2 = random.uniform(0.5, 0.8)

        # Very regular track for easy difficulty
        right_side_factor = 0.6  # Less variation overall

    elif difficulty == "medium":
        # Medium track - moderate width, some challenges
        base_radius_outer = random.uniform(240, 270)
        base_radius_inner = random.uniform(150, 180)
        min_track_width = 70  # Medium track width

        # Moderate features
        sharp_turn_amp = random.uniform(25, 45)
        sharp_turn_freq = random.uniform(2.0, 2.5)
        sharp_turn_phase = random.uniform(0, 2 * math.pi)

        chicane_amp = random.uniform(15, 25)
        chicane_freq = random.uniform(4, 6)
        chicane_phase = random.uniform(0, 2 * math.pi)

        hairpin_amp = random.uniform(20, 30)
        hairpin_freq1 = random.uniform(1.6, 1.9)
        hairpin_freq2 = random.uniform(1.0, 1.5)
        hairpin_phase = random.uniform(0, 2 * math.pi)

        narrow_amp = random.uniform(15, 25)
        narrow_freq1 = random.uniform(0.9, 1.2)
        narrow_freq2 = random.uniform(0.6, 0.9)

        # Moderate track regularity
        right_side_factor = 0.4

    else:  # hard (unknown values were normalised to "hard" above)
        # Hard track - narrower in places, complex features
        base_radius_outer = random.uniform(230, 260)
        base_radius_inner = random.uniform(140, 170)
        min_track_width = 50  # Tighter track in places

        # Challenging features
        sharp_turn_amp = random.uniform(40, 60)
        sharp_turn_freq = random.uniform(2.5, 3.2)
        sharp_turn_phase = random.uniform(0, 2 * math.pi)

        chicane_amp = random.uniform(25, 40)
        chicane_freq = random.uniform(6, 8)
        chicane_phase = random.uniform(0, 2 * math.pi)

        hairpin_amp = random.uniform(35, 50)
        hairpin_freq1 = random.uniform(1.9, 2.2)
        hairpin_freq2 = random.uniform(1.4, 1.7)
        hairpin_phase = random.uniform(0, 2 * math.pi)

        narrow_amp = random.uniform(30, 45)
        narrow_freq1 = random.uniform(1.2, 1.6)
        narrow_freq2 = random.uniform(0.8, 1.2)

        # Less regular track for hard difficulty
        right_side_factor = 0.3

    # Add S-curves for medium and hard difficulties
    s_curve_amp = 0
    s_curve_freq = 0
    s_curve_phase = 0
    if difficulty == "medium":
        s_curve_amp = random.uniform(15, 25)
        s_curve_freq = random.uniform(4, 5)
        s_curve_phase = random.uniform(0, 2 * math.pi)
    elif difficulty == "hard":
        s_curve_amp = random.uniform(30, 40)
        s_curve_freq = random.uniform(5, 7)
        s_curve_phase = random.uniform(0, 2 * math.pi)

    # Add decreasing radius turns for hard difficulty
    decreasing_radius_effect = 0
    decreasing_radius_freq = 0
    decreasing_radius_phase = 0
    if difficulty == "hard":
        decreasing_radius_effect = random.uniform(0.1, 0.2)
        decreasing_radius_freq = random.uniform(2, 3)
        decreasing_radius_phase = random.uniform(0, 2 * math.pi)

    # Special track features - only for medium and hard difficulties
    special_features = []

    # Figure 8 crossover (hard only)
    has_figure_8 = False
    figure_8_position = 0
    if difficulty == "hard" and random.random() < 0.5:  # 50% chance for hard tracks
        has_figure_8 = True
        figure_8_position = random.uniform(0, 2 * math.pi)

    # Split path that rejoins (medium and hard)
    has_split_path = False
    split_start = 0
    split_end = 0
    if difficulty in ["medium", "hard"] and random.random() < 0.3:  # 30% chance
        has_split_path = True
        split_start = random.uniform(0, math.pi)
        split_length = random.uniform(math.pi/4, math.pi/2)
        split_end = split_start + split_length

    # Variable track width (more extreme in hard)
    width_variation_amp = 0
    if difficulty == "medium":
        width_variation_amp = random.uniform(5, 15)
    elif difficulty == "hard":
        width_variation_amp = random.uniform(15, 30)

    # Generate track points
    for i in range(n_points):
        angle = 2 * math.pi * i / n_points

        # Scale features down near angle 0 (right side, where the car starts):
        # side_factor == right_side_factor at angle 0 and 1.0 at angle pi.
        side_factor = right_side_factor + (1 - right_side_factor) * (1 - math.cos(angle)) / 2

        # Add complexity with multiple sine waves of different frequencies and phases
        # Sharp turns with higher amplitude components
        sharp_turn = sharp_turn_amp * side_factor * math.sin(angle * sharp_turn_freq + sharp_turn_phase)

        # Chicanes with higher frequency components
        chicane = chicane_amp * side_factor * math.sin(angle * chicane_freq + chicane_phase)

        # S-curves for medium/hard
        s_curve = 0
        if difficulty in ["medium", "hard"]:
            s_curve = s_curve_amp * side_factor * math.sin(angle * s_curve_freq + s_curve_phase) * math.cos(angle * (s_curve_freq/2))

        # Additional complexity with asymmetric features
        hairpin_turn = hairpin_amp * side_factor * math.sin(angle * hairpin_freq1 - hairpin_phase) * math.sin(angle * hairpin_freq2)

        # Narrow section in one part of the track
        narrow_section = narrow_amp * side_factor * math.sin(angle * narrow_freq1) * math.cos(angle * narrow_freq2)

        # Decreasing radius turns for hard
        decreasing_radius = 0
        if difficulty == "hard":
            # Make certain turns get progressively tighter
            decreasing_radius = decreasing_radius_effect * math.sin(angle * decreasing_radius_freq + decreasing_radius_phase)
            decreasing_radius = decreasing_radius * decreasing_radius * 100  # Square it for more pronounced effect

        # Figure 8 effect - create a crossover point for hard tracks
        figure_8_effect = 0
        if has_figure_8:
            # Create a "pinch" in the track within 0.3 rad (wrapping around 2*pi)
            figure_8_delta = abs(angle - figure_8_position) % (2 * math.pi)
            if figure_8_delta < 0.3 or figure_8_delta > (2 * math.pi - 0.3):
                figure_8_effect = -30  # Pinch the track inward at the crossover point

        # Combine all features
        r_outer = base_radius_outer + sharp_turn + chicane + s_curve + hairpin_turn - abs(narrow_section) * 0.5 + figure_8_effect
        # Make inner track follow but with less extreme features for drivability
        r_inner = base_radius_inner + sharp_turn * 0.7 + chicane * 0.6 + s_curve * 0.7 + hairpin_turn * 0.7 + narrow_section * 0.3 + figure_8_effect

        # Apply decreasing radius to both curves if present
        if decreasing_radius > 0:
            r_outer -= decreasing_radius
            r_inner -= decreasing_radius * 0.7

        # Variable track width
        if width_variation_amp > 0:
            width_variation = width_variation_amp * math.sin(angle * 3.7 + 1.2) * math.sin(angle * 2.3)
            r_inner += width_variation  # Adjust inner radius to create width variation

        # Keep the inner wall at least min_track_width inside the outer wall
        # (checked before smoothing; smoothing can shift this slightly)
        if r_inner > r_outer - min_track_width:
            r_inner = r_outer - min_track_width

        outer_points.append((
            center_x + r_outer * math.cos(angle),
            center_y + r_outer * math.sin(angle)
        ))

        inner_points.append((
            center_x + r_inner * math.cos(angle),
            center_y + r_inner * math.sin(angle)
        ))

    # Special case for split paths (currently just visual markers, not actual splits)
    if has_split_path:
        # Only records the angular range; no alternate geometry is generated
        special_features.append(("split_path", split_start, split_end))

    # Smooth each wall once to soften very tight corners
    outer_points = smooth_track(outer_points)
    inner_points = smooth_track(inner_points)

    # Create the track with difficulty info
    track = Track(outer_points, inner_points)
    track.difficulty = difficulty  # Store the difficulty for display
    track.special_features = special_features  # Store any special features

    return track

def smooth_track(points):
    """Return a smoothed copy of a closed polyline.

    Each output point is the mean of the corresponding input point and its
    two neighbours. Indices wrap around, so the first and last points count
    as adjacent (the track is a closed loop). The input is left unmodified.

    Args:
        points: Indexable sequence of (x, y) points.

    Returns:
        A new list of (x, y) tuples of the same length.
    """
    count = len(points)
    result = []
    for idx, here in enumerate(points):
        before = points[(idx - 1) % count]
        after = points[(idx + 1) % count]
        # Summed in previous + current + next order, then divided by 3
        result.append((
            (before[0] + here[0] + after[0]) / 3,
            (before[1] + here[1] + after[1]) / 3,
        ))
    return result

def wait_for_key():
    """Block until the user presses a key or clicks a mouse button.

    While waiting, prints a status line about every 0.5 seconds and sleeps
    briefly between polls. A window-close event (``pygame.QUIT``) shuts
    pygame down and exits the process.
    """
    print("Waiting for key press...")
    start_time = time.time()
    while True:
        current_time = time.time()
        if current_time - start_time > 0.5:  # Print a message every 0.5 seconds
            print("Still waiting for key press...")
            start_time = current_time

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                print("Quit event received during wait_for_key")
                pygame.quit()
                sys.exit()
            if event.type == pygame.KEYDOWN:
                print(f"Key pressed: {pygame.key.name(event.key)}")
                return
            if event.type == pygame.MOUSEBUTTONDOWN:
                print(f"Mouse button pressed: {event.button}")
                return
        # pygame.time.wait() requires a millisecond argument; calling it
        # without one raises TypeError. Sleep 10 ms to avoid hogging the CPU.
        pygame.time.wait(10)

def handle_events():
    """Drain pending pygame events and report whether the program should quit.

    Pressing F flips the global ``FAST_MODE`` flag and prints the new state.
    Returns True as soon as a window-close event or an ESC key press is seen.
    Any remaining events from the same ``pygame.event.get()`` batch are not
    examined. Returns False otherwise.
    """
    global FAST_MODE

    for evt in pygame.event.get():
        if evt.type == pygame.QUIT:
            return True
        if evt.type != pygame.KEYDOWN:
            continue
        if evt.key == pygame.K_ESCAPE:
            return True
        if evt.key == pygame.K_f:
            FAST_MODE = not FAST_MODE
            state = 'ON - FPS limit removed' if FAST_MODE else 'OFF - FPS limited to 60'
            print(f"Fast mode {state}")

    return False

def main(regenerate_tracks=False):
    """Run the continuous hill-climbing training loop with a pygame display.

    Opens an 800x600 window and builds a track with
    ``create_track(TRACK_DIFFICULTY)``. It then repeatedly mutates the hill
    climber's best controller and evaluates each mutant (rendered, alongside
    the all-time best controller). A mutant is kept when its fitness beats
    the current best.

    The loop runs until the window is closed or ESC is pressed; both call
    ``sys.exit()``. Any other exception is printed with its traceback and the
    process exits with code 1.

    Args:
        regenerate_tracks: When True, a fresh track is generated at the start
            of every iteration after the first. Defaults to False, which keeps
            training on the initial track.
    """
    global FAST_MODE  # the post-result event loop below toggles it on F

    try:
        print("Starting self-driving car simulator...")
        # Initialize pygame
        pygame.init()
        width, height = 800, 600
        screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Self-Driving Car with Hill Climbing")

        # Initialize fonts
        font = pygame.font.Font(None, 36)
        small_font = pygame.font.Font(None, 20)

        # CONTINUOUS TRAINING (optionally regenerating the track)
        iteration_counter = 0
        track_counter = 0
        best_fitness_ever = 0
        current_step_limit = 4000  # Initial step budget; raised by 100 when a good run hits it

        # Create the first track
        track = create_track(TRACK_DIFFICULTY)
        track_counter += 1

        # Calculate dynamic starting position
        start_pos, start_angle = calculate_starting_position(track)

        # NOTE(review): an earlier comment described the inputs as 7 sensors +
        # steering angle + velocity (9 in total), yet input_size=7 is passed
        # here. Confirm what HillClimber expects.
        hill_climber = HillClimber(track, start_pos, start_angle, input_size=7)

        # Store the initial controller as the all-time best
        all_time_best_controller = copy.deepcopy(hill_climber.best_controller)

        # Using an infinite loop to run forever
        while True:
            if regenerate_tracks and iteration_counter > 0:
                # Only count tracks that are actually created, so the
                # "Track #N" label stays accurate.
                track = create_track(TRACK_DIFFICULTY)
                track_counter += 1

            # Calculate the starting position for the current track
            start_pos, start_angle = calculate_starting_position(track)

            # Update the hill climber's track and starting position
            hill_climber.track = track
            hill_climber.start_pos = start_pos
            hill_climber.start_angle = start_angle

            # Show an overview of the current training state
            screen.fill((0, 0, 0))
            track.draw(screen)
            overlay = pygame.Surface((width, height), pygame.SRCALPHA)
            overlay.fill((0, 0, 0, 128))
            screen.blit(overlay, (0, 0))

            # Visualize the starting position
            pygame.draw.circle(screen, (0, 255, 0), (int(start_pos[0]), int(start_pos[1])), 8)
            # Draw a line showing the starting angle
            line_end_x = start_pos[0] + math.cos(start_angle) * 30
            line_end_y = start_pos[1] + math.sin(start_angle) * 30
            pygame.draw.line(screen, (0, 255, 0), start_pos, (line_end_x, line_end_y), 2)

            # Display training info
            info_text = font.render(f"Track #{track_counter}, Iteration #{iteration_counter+1}", 
                                   True, (255, 255, 255))
            fitness_text = font.render(f"Best Fitness: {hill_climber.best_fitness:.2f} | Step Limit: {current_step_limit}", 
                                      True, (255, 255, 255))

            # Color-code difficulty text based on level
            difficulty_color = (255, 255, 255)
            if track.difficulty == "easy":
                difficulty_color = (0, 255, 0)  # Green
            elif track.difficulty == "medium":
                difficulty_color = (255, 165, 0)  # Orange
            elif track.difficulty == "hard":
                difficulty_color = (255, 0, 0)  # Red

            difficulty_text = font.render(f"Difficulty: {track.difficulty.upper()}", 
                                         True, difficulty_color)

            speed_mode = "FAST MODE ENABLED" if FAST_MODE else "Normal Speed"
            speed_text = font.render(f"Speed: {speed_mode}", True, (255, 255, 0) if FAST_MODE else (200, 200, 200))

            if all_time_best_controller is not None:
                all_time_text = font.render(f"All-time Best Fitness: {best_fitness_ever:.2f}", True, (0, 255, 0))
                screen.blit(all_time_text, (width//2 - all_time_text.get_width()//2, 190))

            continue_text = small_font.render("Press ESC to exit, F to toggle fast mode, any other key to jump ahead 20 iterations", 
                                            True, (200, 200, 200))

            screen.blit(info_text, (width//2 - info_text.get_width()//2, 30))
            screen.blit(fitness_text, (width//2 - fitness_text.get_width()//2, 70))
            screen.blit(difficulty_text, (width//2 - difficulty_text.get_width()//2, 110))
            screen.blit(speed_text, (width//2 - speed_text.get_width()//2, 150))
            screen.blit(continue_text, (width//2 - continue_text.get_width()//2, height - 30))

            pygame.display.flip()

            # Check for events (F key for fast mode)
            should_quit = handle_events()
            if should_quit:
                pygame.quit()
                sys.exit()

            # Create a mutated controller using the current mutation setting
            mutated_controller = hill_climber.best_controller.mutate(*hill_climber.mutation_parameters[hill_climber.mutation_index])

            # Evaluate controller with both cars running
            result = hill_climber.evaluate(
                mutated_controller, 
                max_steps=current_step_limit, 
                render=True, 
                debug=False,
                best_controller=all_time_best_controller  # Add the all-time best controller for dual simulation
            )

            # evaluate() may return (fitness, reached_limit) or a bare fitness
            if isinstance(result, tuple) and len(result) == 2:
                fitness, reached_limit = result
                if reached_limit and fitness > 0.75 * hill_climber.best_fitness:
                    current_step_limit += 100
                    print(f"Car reached step limit! Increasing to {current_step_limit} steps")
            else:
                fitness = result
                reached_limit = False

            # If better, keep it
            improvement = False
            if fitness > hill_climber.best_fitness:
                hill_climber.attempts_without_improvement = 0
                hill_climber.mutation_index = 0  # Reset mutation index on improvement

                improvement = True

                hill_climber.best_controller = mutated_controller
                hill_climber.best_fitness = fitness
                print(f"Iteration {iteration_counter+1}: New best fitness: {hill_climber.best_fitness}")

                # Print the coefficients ONLY when there's an improvement
                acc_coeffs = [f"{coeff:.4f}" for coeff in mutated_controller.acceleration_coeffs]
                steer_coeffs = [f"{coeff:.4f}" for coeff in mutated_controller.steering_coeffs]
                print(f"New acc coeffs: {', '.join(acc_coeffs)}")
                print(f"New steer coeffs: {', '.join(steer_coeffs)}")

                # Update the all-time best fitness and controller if necessary
                if fitness > best_fitness_ever:
                    best_fitness_ever = fitness
                    all_time_best_controller = copy.deepcopy(mutated_controller)
                    print(f"New all-time best fitness: {best_fitness_ever}!")
            else:
                hill_climber.attempts_without_improvement += 1
                # With a threshold of 1 this rotates the mutation setting after
                # every unsuccessful attempt.
                if hill_climber.attempts_without_improvement >= 1:
                    hill_climber.attempts_without_improvement = 0
                    hill_climber.mutation_index = (hill_climber.mutation_index + 1) % len(hill_climber.mutation_parameters)

            # Show brief result flash
            screen.fill((0, 0, 0))
            track.draw(screen)
            overlay = pygame.Surface((width, height), pygame.SRCALPHA)
            overlay.fill((0, 0, 0, 128))
            screen.blit(overlay, (0, 0))

            if improvement:
                result_text = font.render(f"IMPROVED! New fitness: {fitness:.2f}", 
                                         True, (0, 255, 0))
            else:
                result_text = font.render(f"No improvement. Fitness: {fitness:.2f}", 
                                         True, (255, 100, 100))

            screen.blit(result_text, (width//2 - result_text.get_width()//2, height//2))
            pygame.display.flip()

            # Brief pause to see result
            if not FAST_MODE:
                pygame.time.wait(500)  # Just half a second pause

            # Check for user input. These events are removed from the queue
            # here, so handle_events() will never see them; every key must be
            # handled in this loop.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        # Exit if user presses escape
                        print("User exited training loop")
                        pygame.quit()
                        sys.exit()
                    elif event.key == pygame.K_f:
                        # Toggle fast mode (same behaviour as handle_events())
                        FAST_MODE = not FAST_MODE
                        print(f"Fast mode {'ON - FPS limit removed' if FAST_MODE else 'OFF - FPS limited to 60'}")
                    else:
                        # NOTE: this only advances the iteration counter shown
                        # on screen; no extra training iterations are run.
                        iteration_counter += 19  # Will be incremented again below
                        print("Jumping ahead 20 iterations")

            # Increment counters
            iteration_counter += 1

    except Exception as e:
        print(f"Error occurred: {e}")
        traceback.print_exc()
        pygame.quit()
        sys.exit(1)

if __name__ == "__main__":
    try:
        # Try the full self-driving car simulation first
        main()
    except SystemExit as exit_request:
        # main() reports its own failures by printing a traceback and calling
        # sys.exit(1). SystemExit is not an Exception subclass, so without
        # this branch the fallback below could never run. A clean quit
        # (sys.exit() with no code, or code 0) is propagated unchanged.
        if exit_request.code in (None, 0):
            raise
        print("\nTrying simple test instead...")
        # NOTE(review): main() has already called pygame.quit(); this assumes
        # simple_test() (re)initialises pygame itself — confirm.
        simple_test()
    except Exception as e:
        print(f"Main function failed with error: {e}")
        traceback.print_exc()
        print("\nTrying simple test instead...")
        # If it fails, try the simple test
        simple_test()