# FILE: ti_example_with_pump_data.py

# torch imports
import torch
import torch.nn as nn
import torchinfo
import torchmetrics
import torchmetrics.regression

# TI tinyml quantization wrappers (QAT and PTQ)
from tinyml_torchmodelopt.quantization import TINPUTinyMLQATFxModule, TINPUTinyMLPTQFxModule

# other imports
import numpy as np
from typing import Tuple, List

# --- NEW: Import your custom DataModule ---
from data import *

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- REMOVED: the TI example's data loading helpers (TorqueMeasurementDataset, get_dataset_from_csv, get_dataloader); batches now come from PumpDataModule ---

def metric(y_pred, y_target, metric_name):
    """Compute a named torchmetrics regression score.

    Args:
        y_pred: prediction tensor.
        y_target: target tensor compared against ``y_pred``.
        metric_name: ``'r2'`` (R2Score) or ``'smape'``
            (SymmetricMeanAbsolutePercentageError).

    Returns:
        The score as a Python float.

    Raises:
        RuntimeError: if ``metric_name`` is not one of the supported names.
    """
    if metric_name not in ('r2', 'smape'):
        raise RuntimeError(f"Unknown metric: {metric_name}")

    regression = torchmetrics.regression
    if metric_name == 'r2':
        scorer = regression.R2Score()
    else:
        scorer = regression.SymmetricMeanAbsolutePercentageError()

    # Evaluate on the predictions' device so the metric state and inputs agree.
    scorer = scorer.to(y_pred.device)
    return scorer(y_pred, y_target).cpu().item()

def train(dataloader: DataLoader, model: nn.Module, loss_fn, optimizer, lambda_reg: float = 1e-4):
    """Run one training epoch with an L1 penalty on all model parameters.

    Each batch is unpacked as ``(inputs, targets, _)``; the third item is ignored.
    Inputs are reshaped with ``unsqueeze(1).unsqueeze(3)``. Assuming the loader
    yields 2-D ``(batch, length)`` inputs (TODO confirm against PumpDataModule),
    this gives the 4-D ``(batch, 1, length, 1)`` layout the Conv2d model needs.

    Args:
        dataloader: iterable of ``(inputs, targets, _)`` batches; must support ``len()``.
        model: network to train; it is put into train mode here.
        loss_fn: data loss, called as ``loss_fn(pred, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        lambda_reg: weight of the L1 penalty. No L2 term is added here, because
            the optimizers built in ``train_model`` already apply ``weight_decay``.

    Returns:
        ``(avg_loss, model, loss_fn, optimizer)``. ``avg_loss`` is the mean
        per-batch loss INCLUDING the L1 penalty, so it is not directly comparable
        to a validation loss computed with ``loss_fn`` alone.

    Raises:
        ZeroDivisionError: if ``dataloader`` has no batches.
    """
    total_loss = 0.0
    model.train()
    for X, y, _ in dataloader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        X = X.unsqueeze(1).unsqueeze(3)
        pred = model(X)

        # L1 over every parameter (weights, biases and BatchNorm affine params).
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        loss = loss_fn(pred, y) + lambda_reg * l1_norm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader), model, loss_fn, optimizer

def calibrate(dataloader: DataLoader, model: nn.Module, loss_fn):
    """Run one PTQ calibration pass: forward every batch without gradients.

    Batches are ``(inputs, labels, _)`` triples. Inputs get the same
    ``unsqueeze(1).unsqueeze(3)`` reshape used in training.

    Returns:
        ``(avg_loss, model, loss_fn, None)``. ``avg_loss`` is the mean per-batch
        ``loss_fn`` value. ``None`` stands in for the optimizer slot that
        ``train`` returns.

    Raises:
        ZeroDivisionError: if ``dataloader`` has no batches.
    """
    running = 0
    # NOTE(review): train mode (under no_grad) is kept from the TI example;
    # presumably so the quantization wrapper keeps updating its statistics — confirm.
    model.train()
    with torch.no_grad():
        for batch_inputs, batch_labels, _ in dataloader:
            batch_inputs = batch_inputs.to(DEVICE).unsqueeze(1).unsqueeze(3)
            batch_labels = batch_labels.to(DEVICE)
            running += loss_fn(model(batch_inputs), batch_labels).item()

    return running / len(dataloader), model, loss_fn, None

def get_nn_model(in_channels: int, hidden_channels: List[int], feature_size: Tuple[int, int], out_channels: int, normalize_input: bool = True, device=None) -> nn.Module:
    """Build the TI-example convolutional regression network.

    Layer order:
      1. BatchNorm2d on the input, or Identity when ``normalize_input`` is False.
      2. One Conv2d/BatchNorm2d/ReLU stage per entry of ``hidden_channels``,
         each with kernel (3, 1), padding (1, 0) and stride (2, 1).
      3. AvgPool2d((4, 1), stride 1).
      4. Flatten -> Linear(-> 64) -> ReLU -> Linear(-> out_channels).

    Args:
        in_channels: channel count C of the 4-D ``(N, C, H, W)`` input.
        hidden_channels: output channels of each conv stage.
        feature_size: ``(height, width)`` of the pooled feature map. The first
            Linear layer's input size is computed statically from it. It MUST
            equal the spatial size actually produced for your input, or the
            forward pass fails with a shape mismatch. For example,
            ``hidden_channels=[8, 16, 32]`` with ``feature_size=(4, 1)`` needs an
            input height of 49-56 and width 1.
        out_channels: number of regression outputs.
        normalize_input: prepend a BatchNorm2d on the input.
        device: where to place the model; defaults to the module-level ``DEVICE``.

    Returns:
        The model, moved to the selected device.
    """

    def get_conv_bn_relu(in_channels: int, out_channels: int, kernel_size: Tuple[int, int], padding=None, stride=1):
        """Return ``[Conv2d, BatchNorm2d, ReLU]``; padding defaults to half the kernel per axis."""
        if padding is None:
            # Explicit None check so an intentional padding of 0 is honoured.
            padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        ]

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            # Always occupy index 0 (BatchNorm or Identity) so layer indices,
            # and therefore state_dict keys, don't shift with normalize_input.
            layers = [nn.BatchNorm2d(num_features=in_channels) if normalize_input else nn.Identity()]

            in_ch = in_channels
            for h_ch in hidden_channels:
                # Stride (2, 1) roughly halves the height; width is unchanged.
                layers += get_conv_bn_relu(in_ch, h_ch, kernel_size=(3, 1), padding=None, stride=(2, 1))
                in_ch = h_ch

            # Fixed-size pooling (AvgPool2d instead of the TI example's MaxPool2d):
            # output height = conv output height - 3.
            layers += [nn.AvgPool2d(kernel_size=(4, 1), stride=(1, 1))]

            # Static flattened size. It is NOT inferred from the input, so
            # feature_size must match the pooled map (see docstring).
            in_fc_ch = in_ch * feature_size[0] * feature_size[1]

            layers += [
                nn.Flatten(),
                nn.Linear(in_fc_ch, out_features=64),
                nn.ReLU(),
                nn.Linear(64, out_features=out_channels),
            ]
            self.layers = nn.ModuleList(layers)

        def forward(self, x: torch.Tensor):
            for layer in self.layers:
                x = layer(x)
            return x

    target_device = DEVICE if device is None else device
    return NeuralNetwork().to(target_device)

def train_model(model: nn.Module, dataloader: DataLoader, val_loader: DataLoader, total_epochs: int, learning_rate: float,QAT = False) -> nn.Module:
    """Train ``model`` for ``total_epochs`` with MSE loss and cosine LR annealing.

    Optimizer choice: SGD (momentum 0.9) when ``QAT`` is True, otherwise AdamW.
    Both use ``weight_decay=1e-4``. After every epoch the validation loss on
    ``val_loader`` is computed and a progress line is printed. The printed LR is
    the one returned by the scheduler after stepping.

    Returns:
        The trained model.
    """
    loss_fn = torch.nn.MSELoss()
    opti = (
        torch.optim.SGD(params=model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
        if QAT
        else torch.optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=1e-4)
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opti, total_epochs)

    for epoch_no in range(1, total_epochs + 1):
        train_loss, model, loss_fn, opti = train(dataloader, model, loss_fn, opti)
        val_loss = validate_epoch(val_loader, model, loss_fn)
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch: {epoch_no}\t LR: {round(current_lr,5)}\t Train Loss: {round(train_loss, 5)}\t Val Loss: {round(val_loss, 5)}")

    return model

def calibrate_model(model: nn.Module, dataloader: DataLoader, total_epochs: int) -> nn.Module:
    """Run ``total_epochs`` PTQ calibration passes, printing the Huber loss of each.

    Returns:
        The calibrated model.
    """
    loss_fn = torch.nn.HuberLoss()
    epoch_no = 0
    while epoch_no < total_epochs:
        loss, model, loss_fn, _ = calibrate(dataloader, model, loss_fn)
        epoch_no += 1
        print(f"Epoch: {epoch_no}\t Loss: {round(loss, 5)}")
    return model

def validate_epoch(dataloader: DataLoader, model: nn.Module, loss_fn):
    """Return the mean per-batch ``loss_fn`` value over ``dataloader``.

    The model runs in eval mode with gradients disabled. Batches are
    ``(inputs, labels, _)`` triples, and inputs get the same
    ``unsqueeze(1).unsqueeze(3)`` reshape as in training.

    Raises:
        ZeroDivisionError: if ``dataloader`` has no batches.
    """
    model.eval()  # BatchNorm uses running statistics in eval mode
    loss_sum = 0
    with torch.no_grad():
        for batch_inputs, batch_labels, _ in dataloader:
            batch_inputs = batch_inputs.to(DEVICE).unsqueeze(1).unsqueeze(3)
            batch_labels = batch_labels.to(DEVICE)
            loss_sum += loss_fn(model(batch_inputs), batch_labels).item()
    return loss_sum / len(dataloader)

def export_model(nn_model, example_input: torch.Tensor, model_name: str, with_quant: bool = False) -> nn.Module:
    """Export ``nn_model`` to ONNX at ``model_name``, converting it first if quantized.

    When ``with_quant`` is True, the model's ``convert()`` is called. If the
    converted model provides its own ``export`` method, that method is used.
    Every other case goes through ``torch.onnx.export``. The graph input is
    named ``'input'``.

    Returns:
        The exported model (the converted one when ``with_quant`` is True).
    """
    nn_model.to(DEVICE)
    sample = example_input.to(DEVICE)

    if with_quant:
        nn_model = nn_model.convert()
        if hasattr(nn_model, "export"):
            nn_model.export(sample, model_name, input_names=['input'])
            return nn_model

    torch.onnx.export(nn_model, sample, model_name, input_names=['input'])
    return nn_model

def validate_model(model: nn.Module, test_loader: DataLoader) -> Tuple[float, float]:
    """Compute R2 and SMAPE of ``model`` over the whole ``test_loader``.

    All batches are forwarded in eval mode without gradients. Predictions and
    labels are concatenated along dim 0, and each metric is computed once over
    the full set, not averaged per batch. Scores are computed on the data as
    the loader yields it (scaled space).

    Returns:
        ``(r2_score, smape_score)`` as Python floats.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels, _ in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            # Same 4-D reshape used by the training/validation loops.
            inputs_4d = inputs.unsqueeze(1).unsqueeze(3)
            all_preds.append(model(inputs_4d))
            all_labels.append(labels)

    if not all_preds:
        # torch.cat([]) would fail with an opaque error; say what went wrong.
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    r2_score = metric(all_preds, all_labels, 'r2')
    smape_score = metric(all_preds, all_labels, 'smape')
    return r2_score, smape_score

if __name__ == '__main__':
    # --- CONFIGURATION (partly from the TI example, partly project-specific) ---
    MODEL_NAME = "pump_model_ti_arch.onnx"
    DATA_FILE = '../f28p55-ai-control/FF-control/dataset/rampupV6_train.npz'  # project data file
    NUM_EPOCHS = 60
    BATCH_SIZE = 2048
    LEARNING_RATE = 0.001
    QUANTIZATION_METHOD = 'QAT'  # 'QAT' or 'PTQ'; any other value skips quantization
    QUANTIZATION_DEVICE_TYPE = 'TINPU'  # informational: only the TINPU wrappers are used below
    NORMALIZE_INPUT = True
    NUM_WORKERS = 1  # DataLoader workers passed to PumpDataModule

    # --- Data ---
    print("--- Setting up PumpDataModule ---")
    dm = PumpDataModule(
        data_path=DATA_FILE,
        batch_size=BATCH_SIZE,
        train_ratio=0.9,
        num_workers=NUM_WORKERS,
    )
    dm.setup()
    train_loader = dm.train_dataloader()
    test_loader = dm.val_dataloader()
    print("✅ DataModule setup complete.")

    # --- Model parameters ---
    IN_CHANNELS = 1  # inputs get a single channel axis via unsqueeze(1)
    NUM_TARGETS = dm.output_size  # expected to be 3 for this dataset -- confirm

    # One example input, reshaped exactly like the train/validate loops do. It is
    # used for torchinfo and for tracing the quantization wrapper.
    example_input, _, _ = next(iter(train_loader))
    example_input = example_input[:1].unsqueeze(1).unsqueeze(3)

    # TI model architecture
    nn_model = get_nn_model(
        in_channels=IN_CHANNELS,
        hidden_channels=[8, 16, 32],
        feature_size=(4, 1),
        out_channels=NUM_TARGETS,
        normalize_input=NORMALIZE_INPUT,
    )

    torchinfo.summary(nn_model, input_data=example_input.to(DEVICE))

    # --- FP32 Training ---
    print("\n--- Starting FP32 Training ---")
    nn_model = train_model(nn_model, train_loader, test_loader, NUM_EPOCHS, LEARNING_RATE, QAT=False)
    r2_score, mape_score = validate_model(nn_model, test_loader)
    print(f"\nTrained FP32 Model R2-Score: {round(r2_score, 5)}\tSMAPE: {round(mape_score, 3)}\n")

    # This `if` runs at module scope, so this rebinds the module-level DEVICE.
    # Every helper above reads DEVICE at call time, so the quantization stage
    # runs on CPU.
    DEVICE = 'cpu'
    nn_model = nn_model.to(DEVICE)

    # --- Quantization ---
    if QUANTIZATION_METHOD in ('QAT', 'PTQ'):
        MODEL_NAME = 'quant_' + MODEL_NAME
        quant_epochs = max(NUM_EPOCHS // 2, 5)

        # Use the TI wrapper that matches the method: PTQ must not reuse the QAT wrapper.
        quant_wrapper = TINPUTinyMLQATFxModule if QUANTIZATION_METHOD == 'QAT' else TINPUTinyMLPTQFxModule
        quant_model = quant_wrapper(
            nn_model, total_epochs=quant_epochs, example_inputs=example_input
        )

        if QUANTIZATION_METHOD == 'QAT':
            print("\n--- Starting QAT Fine-Tuning ---")
            quant_learning_rate = LEARNING_RATE / 10
            quant_model = train_model(quant_model, train_loader, test_loader, quant_epochs, quant_learning_rate, QAT=True)
        else:
            print("\n--- Starting PTQ Calibration ---")
            quant_model = calibrate_model(quant_model, train_loader, quant_epochs)

        r2_score, mape_score = validate_model(quant_model, test_loader)
        print(f"\n{QUANTIZATION_METHOD} Model R2-Score: {round(r2_score, 5)}\tSMAPE: {round(mape_score, 3)}\n")

        print("--- Exporting Quantized Model ---")
        export_model(quant_model, example_input, MODEL_NAME, with_quant=True)
        print(f"✅ Exported quantized model to {MODEL_NAME}")
    else:
        print("No Quantization method specified.")