#include <stdlib.h>
#include <stdio.h>
#include <stdbool.h>
#include <unistd.h>
#include <stdint.h>
#include <inttypes.h>
#include <errno.h>
#include <limits.h>

#include <pthread.h>
#include <semaphore.h>

#include "fonctions.h"


// Input/output streams. Both are opened by reading() and closed by main()
// once all threads have been joined.
FILE *file_in;
FILE *file_out;
// Input and output paths, taken from the command-line operands in main().
char *f_in;
char *f_out;
// Number of calculating() threads; defaults to 1 and is overridden by -N.
int nthreads = 1;
// Queue of numbers: filled by reading(), drained by calculating().
buffer1 *buffer_1;
// Queue of result lists: filled by calculating(), drained by writing().
buffer2 *buffer_2;
// NOTE(review): not referenced by the code in this file; possibly used
// elsewhere (e.g. declared in fonctions.h) - confirm, otherwise remove.
int nlines;


/*
 * Reader thread (start routine for pthread_create; arg is unused).
 *
 * Opens f_in for reading and f_out for writing (file_out is later used by
 * writing()), then pushes every number parsed from f_in into buffer_1 using
 * the bounded-buffer protocol: wait on `free`, append under `mutex`, post
 * `full`. When the input is exhausted it sets buffer_1->can_stop under the
 * mutex so calculators know no further numbers will arrive.
 *
 * Terminates the whole process with EXIT_FAILURE if a file cannot be opened
 * or the input contains a token that cannot be parsed as a number.
 */
void *reading(void *arg) {
    (void) arg;

    file_in = fopen(f_in, "r");
    if (file_in == NULL) {
        perror(f_in);
        exit(EXIT_FAILURE);
    }

    file_out = fopen(f_out, "w");
    if (file_out == NULL) {
        perror(f_out);
        exit(EXIT_FAILURE);
    }

    uint64_t number;
    int scanned;
    // Loop on "== 1", not "!= EOF": fscanf returns 0 on a non-numeric token
    // without consuming it, which would otherwise loop forever re-queuing a
    // stale value. SCNu64 matches uint64_t on every platform, unlike %lu.
    while ((scanned = fscanf(file_in, "%" SCNu64, &number)) == 1) {
        sem_wait(&(buffer_1->free));
        pthread_mutex_lock(&(buffer_1->mutex));
        put_node_t(buffer_1->numbers, number);
        pthread_mutex_unlock(&(buffer_1->mutex));
        sem_post(&(buffer_1->full));
    }

    // Anything other than a clean end of file means part of the input was not
    // queued; fail loudly instead of producing a partial/hanging pipeline.
    if (scanned != EOF || ferror(file_in)) {
        fprintf(stderr, "%s: invalid or unreadable input\n", f_in);
        exit(EXIT_FAILURE);
    }

    pthread_mutex_lock(&(buffer_1->mutex));
    buffer_1->can_stop = true;
    pthread_mutex_unlock(&(buffer_1->mutex));

    return NULL;
}


/*
 * Calculator thread (start routine for pthread_create; arg is unused).
 *
 * Repeatedly takes one number out of buffer_1, computes prime_list(number)
 * and queues the resulting list into buffer_2 for the writer. Both hand-offs
 * follow the same bounded-buffer protocol: wait on `free`/`full`, touch the
 * list under `mutex`, then post the opposite semaphore.
 *
 * The loop ends once the reader has set can_stop and buffer_1 is empty. A
 * thread can still block in sem_wait(&buffer_1->full) after a sibling took
 * the last number; main() cancels the calculators to release such threads.
 */
void *calculating(void *arg) {
    (void) arg;

    for (;;) {
        // can_stop and the list length are written by other threads while
        // holding buffer_1->mutex, so they must be read under it as well.
        pthread_mutex_lock(&(buffer_1->mutex));
        bool finished = buffer_1->can_stop && buffer_1->numbers->length == 0;
        pthread_mutex_unlock(&(buffer_1->mutex));
        if (finished) {
            break;
        }

        sem_wait(&(buffer_1->full));
        pthread_mutex_lock(&(buffer_1->mutex));
        uint64_t number = get_node_t(buffer_1->numbers);
        pthread_mutex_unlock(&(buffer_1->mutex));
        sem_post(&(buffer_1->free));

        // Computed outside any lock so calculators can run in parallel.
        list_t *list_prime = prime_list(number);

        sem_wait(&(buffer_2->free));
        pthread_mutex_lock(&(buffer_2->mutex));
        put_node_lst(buffer_2->prime_numbers, list_prime);
        pthread_mutex_unlock(&(buffer_2->mutex));
        sem_post(&(buffer_2->full));
    }

    return NULL;
}

/*
 * Writer thread (start routine for pthread_create; arg is unused).
 *
 * Expects exactly count_lines(f_in) results. For each one it takes a list
 * out of buffer_2 (bounded-buffer protocol: wait `full`, remove under
 * `mutex`, post `free`), writes its elements space-separated followed by a
 * newline to file_out, then clears and frees the list, which this thread
 * owns once removed from buffer_2.
 *
 * file_out is opened by reading(); it is guaranteed to be open here because
 * a result can only exist after the reader has queued a number, which it
 * does after opening file_out (ordered through the semaphores).
 */
void *writing(void *arg) {
    (void) arg;

    // Named "remaining" so it does not shadow the file-scope nlines.
    int remaining = count_lines(f_in);
    while (remaining > 0) {
        remaining--;

        sem_wait(&(buffer_2->full));
        pthread_mutex_lock(&(buffer_2->mutex));
        list_t *removed = get_node_lst(buffer_2->prime_numbers);
        pthread_mutex_unlock(&(buffer_2->mutex));
        sem_post(&(buffer_2->free));

        // get_node_t is called once per element, draining the list in order.
        int length = removed->length;
        for (int i = 0; i < length; i++) {
            uint64_t prime = get_node_t(removed);
            // PRIu64 matches uint64_t on every platform, unlike %lu.
            fprintf(file_out, "%" PRIu64 " ", prime);
        }
        clear_list_t(removed);
        free(removed);
        fprintf(file_out, "\n");
    }

    return NULL;
}

int main(int argc, char *argv[]) {
    int opt;
    while ((opt = getopt(argc, argv, "N:")) != -1) {
        switch(opt) {
            case 'N':
                nthreads = atoi(optarg);
                break;
            default:
                break;
        }
        f_in = argv[optind];
        f_out = argv[optind + 1];
    }
    
    
    buffer_1 = init_buffer_1(nthreads);
    if (!buffer_1) {
        free(buffer_1);
        return -1;
    }

    buffer_2 = init_buffer_2(nthreads);
    if (!buffer_2) {
        free(buffer_2);
        return -1;
    }

    pthread_t read;
    pthread_t calc[nthreads];
    pthread_t write;

    if (pthread_create(&read, NULL, &reading, NULL) != 0) return -1;

    for (int i = 0; i < nthreads; i++) {
        if (pthread_create(&calc[i], NULL, &calculating, NULL) != 0) return -1;
    }

    if (pthread_create(&write, NULL, &writing, NULL) != 0) return -1;

    if (pthread_join(read, NULL) != 0) return -1;
    
    for (int i = 0; i < nthreads; i++) {
        pthread_cancel(calc[i]);
        if (pthread_join(calc[i], NULL) != 0) return -1;
    }

    if (pthread_join(write, NULL) != 0) return -1;

    fclose(file_in);
    fclose(file_out);

    clear_list_t(buffer_1->numbers);
    free(buffer_1->numbers);
    free(buffer_2->prime_numbers);

    free(buffer_1);
    free(buffer_2);

    return 0;
}