Ryanhub - file viewer
filename: train.c
branch: main
back to repo
// train.c
// training routine using MSE and standard backprop

#include <stdio.h>

#include "includes.c"

// Computes the summed squared error between `output` and `target` over n
// values and writes dLoss/dsum (gradient w.r.t. each decoder pre-activation)
// into d_sum. Returns the SUM of squared errors, not the mean. The gradient is
// likewise of the sum; callers divide by n only when reporting.
// NOTE(review): the derivative assumes the decoder output is
// (tanh(sum) + 1) / 2, so 2*out - 1 recovers tanh(sum); confirm against decode().
static float output_loss_grad(const float *output, const float *target,
                              float *d_sum, int n) {
    float loss = 0.0f;

    for (int i = 0; i < n; i++) {
        float diff = output[i] - target[i];
        loss += diff * diff;

        float t = 2.0f * output[i] - 1.0f;        // tanh(sum)
        float dout_dsum = 0.5f * (1.0f - t * t);  // d out / d sum
        float dL_dout = 2.0f * diff;              // d (diff^2) / d out
        d_sum[i] = dL_dout * dout_dsum;           // dL/dsum
    }
    return loss;
}

// One SGD step on the decoder: w2 (indexed [j * IMAGESIZE + i], one row per
// latent unit j) and biases b2. Also writes dL/dlatent into d_latent, computed
// from the weights as they were BEFORE this update.
static void update_decoder(Model *m, const float *latent, const float *d_sum,
                           float *d_latent) {
    for (int j = 0; j < LATENT; j++) {
        d_latent[j] = 0;

        for (int i = 0; i < IMAGESIZE; i++) {
            int w2_idx = j * IMAGESIZE + i;
            float grad = d_sum[i] * latent[j];  // dL/dw2

            d_latent[j] += d_sum[i] * m->w2[w2_idx];  // uses pre-update weight
            m->w2[w2_idx] -= LR * grad;
        }
    }

    for (int i = 0; i < IMAGESIZE; i++)
        m->b2[i] -= LR * d_sum[i];  // decoder biases
}

// One SGD step on the encoder: w1 (indexed [i * LATENT + j], one row per
// input pixel i) and biases b1.
// NOTE(review): (1 - latent^2) is the tanh derivative, so this assumes
// latent = tanh(pre-activation); confirm against encode().
static void update_encoder(Model *m, const float *img, const float *latent,
                           const float *d_latent) {
    for (int j = 0; j < LATENT; j++) {
        float dz = d_latent[j] * (1 - latent[j] * latent[j]);  // dL/dpre-activation

        for (int i = 0; i < IMAGESIZE; i++) {
            float grad = dz * img[i];  // dL/dw1
            m->w1[i * LATENT + j] -= LR * grad;
        }

        m->b1[j] -= LR * dz;  // encoder biases
    }
}

// Trains the encoder/decoder model on random images from emojis.bin with
// per-sample SGD on squared reconstruction error. Every `print_every` steps,
// it prints the per-pixel mean loss averaged over that window and saves a
// reconstruction to "<step>.bmp". The final model is written to model.bin.
int main(void) {

    srand((unsigned) time(NULL));

    Model m;
    Dataset d;

    init_model(&m);
    load_dataset("emojis.bin", &d);

    // rand() % 0 is undefined behavior, so refuse to train on an empty dataset.
    if (d.count <= 0) {
        fprintf(stderr, "no images loaded from emojis.bin\n");
        return 1;
    }

    const int steps = 3000;        // training steps (one image per step)
    const int print_every = 1000;  // report loss / save image interval

    float latent[LATENT];
    float output[IMAGESIZE];
    float d_sum[IMAGESIZE];  // dL/dsum for each decoder pre-activation
    float d_latent[LATENT];
    float window_loss = 0;   // summed loss since the last report

    for (int step = 0; step < steps; step++) {

        int idx = rand() % d.count;
        float *img = &d.images[(size_t) idx * IMAGESIZE];

        // ---- forward ----
        encode(&m, img, latent);
        decode(&m, latent, output);

        // ---- loss + backward ----
        float loss = output_loss_grad(output, img, d_sum, IMAGESIZE);
        update_decoder(&m, latent, d_sum, d_latent);
        update_encoder(&m, img, latent, d_latent);

        window_loss += loss;
        if (step % print_every == 0) {
            // Step 0 is a single sample, so it is not averaged over the window.
            if (step != 0)
                printf("step %d loss %f\n", step, window_loss / print_every / IMAGESIZE);
            else
                printf("initial loss %f\n", loss / IMAGESIZE);
            window_loss = 0;

            // Save a reconstruction so progress can be inspected visually.
            float image[IMAGESIZE];
            char filename[16];
            snprintf(filename, sizeof filename, "%d.bmp", step);
            reconstruct(&m, img, image);
            save_bmp(filename, image);
        }
    }

    // Steps after the last checkpoint would otherwise never be reported.
    int unreported = (steps - 1) % print_every;
    if (unreported > 0)
        printf("step %d loss %f\n", steps, window_loss / unreported / IMAGESIZE);

    save_model("model.bin", &m);

    return 0;
}