import pandas as pd
from transformers import pipeline
import numpy as np
import tensorflow_hub as hub
from sklearn.cluster import KMeans

def get_cluster_names_centroid(embeddings, labels, reasons):
    """ 
    Luodaan klustereille nimet perustuen lähimpään kohtaan klusterin keskipisteessä.
    """
    cluster_names = {}
    max_length = 64
    for cluster in set(labels):
        cluster_points = embeddings[labels == cluster]
        centroid = np.mean(cluster_points, axis=0)
        closest_point = np.argmin(np.linalg.norm(cluster_points - centroid, axis=1))
        reason_for_cluster = [reasons[i] for i, c in enumerate(labels) if c == cluster][closest_point]
        cluster_names[cluster] = reason_for_cluster.capitalize()[:max_length] + "..." if len(reason_for_cluster) > max_length else reason_for_cluster.capitalize()
    return cluster_names

# Ladataan data CSV-tiedostosta.
data = pd.read_csv("data-1693396993053.csv")
notes = data["notes"].tolist()

# Luodaan kysymys-vastaus pipeline (QnA), joka tunnistaa mielenosoitusten syyn.
qa_pipeline = pipeline("question-answering", model="mrm8488/spanbert-finetuned-squadv2", tokenizer="mrm8488/spanbert-finetuned-squadv2")
reasons = [qa_pipeline({"context": note, "question": "What is the reason for the protest?"})['answer'] for note in notes]

# Muunnetaan syyt numeerisiksi arvoiksi (embeddings) käyttämällä Google's Universal Sentence Encoderia.
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
embeddings = embed(reasons).numpy()

# Klusteroidaan data käyttämällä K-means -algoritmia.
data["cluster"] = KMeans(n_clusters=14, n_init=10).fit_predict(embeddings)

# Haetaan klustereille nimet.
cluster_names = get_cluster_names_centroid(embeddings, data["cluster"].to_numpy(), reasons)

# Tulostetaan klusterit ja niiden lukumäärät kuvaavilla nimillä, järjestyksessä suurimmasta pienimpään.
for cluster, count in data.groupby("cluster").size().sort_values(ascending=False).items():
    print(f"{cluster_names[cluster]}: {count}")
