import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint  # Lisätään tarvittavat callbackit
import datetime

print("Koulutuksen aloitus:", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

# Määritetään polut
TRAIN_DIR = './dataset/train'
VALID_DIR = './dataset/validation'
TEST_DIR = './dataset/test'

# Ladataan data käyttäen ImageDataGeneratoria
# Koulutusdatalle lisätään data augmentation -toiminnot
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validaatiolle ja testille ei tehdä data augmentationia, vain rescaling
valid_datagen = ImageDataGenerator(rescale=1./255)

# Muuta target_size-arvoksi (224, 224)
train_data = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=(224, 224),
    class_mode='categorical'
)

print(train_data.class_indices)

valid_data = valid_datagen.flow_from_directory(
    VALID_DIR,
    target_size=(224, 224),
    class_mode='categorical'
)

# Muuta input_shape-arvoksi (224, 224, 3)
base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')

# Jäädytetään esikoulutetun mallin kerrokset
for layer in base_model.layers:
    layer.trainable = False

# Lisätään omat kerrokset esikoulutetun mallin päälle
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Määritetään callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)
checkpoint = ModelCheckpoint("best_MobileNetV2_military_vehicles_model.tf", monitor='val_loss', save_best_only=True)

# Koulutetaan malli
model.fit(train_data, validation_data=valid_data, epochs=100, callbacks=[early_stopping, checkpoint]) 

# Tallennetaan malli levylle .tf formaatissa
model.save('MobileNetV2_military_vehicles_model.tf', save_format='tf')

print("Koulutuksen lopetus:", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))