# backprop example from the notes with
# no activation function, no hidden layers
# 3 inputs, 1 output, and THREE samples
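#
# model: yhat = w1*x1 + w2*x2 + w3*x3 + b
# training data (read off from loss() below):
#   x = ( 2,  1,  1), y =  3
#   x = (-1,  2,  3), y = -4
#   x = (-1, -1, -5), y = -5
# loss: mean squared error over the three samples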

def loss(w1, w2, w3, b):
    # mean squared error over the three samples (targets 3, -4, -5)
    yhat1 = 2 * w1 + w2 + w3 + b
    yhat2 = -w1 + 2*w2 + 3*w3 + b
    yhat3 = -w1 - w2 - 5*w3 + b
    return 1/3 * ((yhat1 - 3)**2 + (yhat2 + 4)**2 + (yhat3 + 5)**2)

def one_step(w1, w2, w3, b):
    # forward pass: one prediction per training sample
    yhat1 = 2 * w1 + w2 + w3 + b
    yhat2 = -w1 + 2*w2 + 3*w3 + b
    yhat3 = -w1 - w2 - 5*w3 + b

    # dL/dyhat_i = (2/3) * (yhat_i - y_i), from the 1/3 * (...)**2 loss
    dl_dyhat1 = 2/3 * (yhat1 - 3)
    dl_dyhat2 = 2/3 * (yhat2 + 4)
    dl_dyhat3 = 2/3 * (yhat3 + 5)

    print(f"yhat: {yhat1:.2f} {yhat2:.2f} {yhat3:.2f}")
    print(f"dl_dyhat: {dl_dyhat1:.2f} {dl_dyhat2:.2f} {dl_dyhat3:.2f}")

    # per-sample contributions: each sample's input x_j times its dL/dyhat_i
    print(f"dw1: {2 * dl_dyhat1:.2f} {-1 * dl_dyhat2:.2f} {-1 * dl_dyhat3:.2f}")
    print(f"dw2: {1 * dl_dyhat1:.2f} {2 * dl_dyhat2:.2f} {-1 * dl_dyhat3:.2f}")
    print(f"dw3: {1 * dl_dyhat1:.2f} {3 * dl_dyhat2:.2f} {-5 * dl_dyhat3:.2f}")
    print(f"db: {1 * dl_dyhat1:.2f} {1 * dl_dyhat2:.2f} {1 * dl_dyhat3:.2f}")

    # chain rule, summed over samples: dL/dw_j = sum_i x_j_i * dL/dyhat_i,
    # and dL/db = sum_i dL/dyhat_i
    dl_dw1 = 2 * dl_dyhat1 + (-1) * dl_dyhat2 + (-1) * dl_dyhat3
    dl_dw2 = 1 * dl_dyhat1 + 2 * dl_dyhat2 + (-1) * dl_dyhat3
    dl_dw3 = 1 * dl_dyhat1 + 3 * dl_dyhat2 + (-5) * dl_dyhat3
    dl_db = 1 * dl_dyhat1 + 1 * dl_dyhat2 + 1 * dl_dyhat3

    print(f"gradient: {dl_dw1:.2f} {dl_dw2:.2f} {dl_dw3:.2f} {dl_db:.2f}")

    # gradient descent update with learning rate 1/100
    return (w1 - dl_dw1/100, w2 - dl_dw2/100, w3 - dl_dw3/100, b - dl_db/100)

def hundred_steps(w1, w2, w3, b):
    # repeat the gradient-descent step 100 times
    for _ in range(100):
        (w1, w2, w3, b) = one_step(w1, w2, w3, b)
    return (w1, w2, w3, b)
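
# Optional sanity check (not from the notes): compare the hand-derived
# gradient printed by one_step against a central finite-difference
# estimate of loss(). The name grad_check and the step size h are my
# own choices; this is only a sketch, not part of the original example.
def grad_check(w1, w2, w3, b, h=1e-6):
    params = [w1, w2, w3, b]
    grads = []
    for i in range(4):
        plus = params.copy()
        minus = params.copy()
        plus[i] += h
        minus[i] -= h
        # central difference: (loss(p + h*e_i) - loss(p - h*e_i)) / (2h)
        grads.append((loss(*plus) - loss(*minus)) / (2 * h))
    return grads
# e.g. grad_check(2, -3, 1, 0.5) should match the "gradient:" line
# printed by one_step(2, -3, 1, 0.5).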

(w1, w2, w3, b) = (2, -3, 1, 0.5)

# run training from the initial parameters above
print(f"initial loss: {loss(w1, w2, w3, b):.4f}")
(w1, w2, w3, b) = hundred_steps(w1, w2, w3, b)
print(f"final loss: {loss(w1, w2, w3, b):.4f}")