# run_qat_torchao.py
import importlib.util # <--- For dynamic importing
import os
import sys
import traceback

import joblib
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig
from tqdm import tqdm

# --- Custom Imports ---
from rampupV6.data import PumpDataModule # Data module is still a project dependency

# ==============================================================================
#      CONFIGURATION
# ==============================================================================
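# Packaged-model directory: must contain model_definition.py, best_model.ckpt and
# scalers.joblib. The exported QAT model is also written back into it. Change this
# to target another run (DATA_FILE below may need updating as well).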
ARTIFACT_DIR = "training_runs/FullParamsEmbeddedSmall_20251021-143330/packaged_model"

# --- QAT fine-tuning settings ---
DATA_FILE = "dataset/rampupV6_train.npz"  # dataset passed to PumpDataModule
FINE_TUNE_EPOCHS = 5                       # number of QAT passes over the train split
LEARNING_RATE = 1e-6                       # SGD step size used during QAT
BATCH_SIZE = 2048                          # batch size passed to PumpDataModule
DEVICE = 'cpu'                             # device used for fine-tuning and evaluation

# Simulation constants passed as `constants=` when the checkpoint is reloaded.
# They must be identical to the values train.py used for its loss.
SIM_CONSTANTS_FOR_LOSS = dict(
    alpha1=4.03,
    gamma1=2.14e-02,
    C=6e-4,
    max_rpm=13000.0,
    max_rpm_acc=25400,
    dt=0.0025,
)
# ==============================================================================

# (Helper functions are identical to the TI script)
def dynamically_load_model_class(artifact_dir):
    """Import ``model_definition.py`` from *artifact_dir* and return its
    ``PumpParameterRegressor`` class.

    The file is loaded as a module named ``model_definition``. It is registered
    in ``sys.modules`` before it executes, following the importlib
    "import a source file directly" recipe. This lets code that resolves
    classes through their module name (e.g. ``dataclasses``,
    ``typing.get_type_hints`` or ``pickle``) find it. If executing the file
    raises, any previous ``sys.modules`` entry is restored and the exception
    propagates.

    NOTE: this executes arbitrary Python code from *artifact_dir*; only point
    it at trusted artifact directories.

    Args:
        artifact_dir: directory expected to contain ``model_definition.py``.

    Returns:
        The ``PumpParameterRegressor`` attribute of the loaded module.

    Raises:
        FileNotFoundError: if ``model_definition.py`` is missing.
        ImportError: if no import spec can be built for the file, or the
            module does not define ``PumpParameterRegressor``.
    """
    module_name = "model_definition"
    model_def_path = os.path.join(artifact_dir, "model_definition.py")
    if not os.path.exists(model_def_path):
        raise FileNotFoundError(f"Cannot find 'model_definition.py' in {artifact_dir}")

    spec = importlib.util.spec_from_file_location(module_name, model_def_path)
    # spec_from_file_location may return None; fail with a clear ImportError
    # instead of an AttributeError on `spec.loader` below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {model_def_path}")
    model_module = importlib.util.module_from_spec(spec)

    # Register before executing (as importlib's recipe does), remembering any
    # previous entry so a failed load does not leave a half-initialised module.
    previous_module = sys.modules.get(module_name)
    sys.modules[module_name] = model_module
    try:
        spec.loader.exec_module(model_module)
    except BaseException:
        if previous_module is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous_module
        raise

    if not hasattr(model_module, 'PumpParameterRegressor'):
        raise ImportError(f"Could not find 'PumpParameterRegressor' in {model_def_path}")
    print(f"✅ Dynamically loaded model definition from: {model_def_path}")
    return model_module.PumpParameterRegressor

def evaluate_model(model, dataloader, device='cpu'):
    """Compute the mean MSE of *model* over every sample in *dataloader*.

    The model is switched to eval mode (and left there) and run under
    ``torch.no_grad()``. Each batch's mean loss is weighted by that batch's
    sample count. A smaller final batch therefore does not skew the result,
    as it would with a plain average of per-batch losses.

    Args:
        model: module called on the reshaped inputs.
        dataloader: iterable of ``(inputs, labels_scaled, _)`` batches; the
            third element is ignored. ``inputs`` is reshaped with
            ``unsqueeze(1).unsqueeze(3)`` (presumably (B, L) -> (B, 1, L, 1)
            -- confirm against the data module).
        device: device the inputs and labels are moved to.

    Returns:
        float: sample-weighted mean MSE between predictions and ``labels_scaled``.

    Raises:
        ValueError: if *dataloader* yields no samples.
    """
    model.eval()
    loss_fn = nn.MSELoss()
    weighted_loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, labels_scaled, _ in tqdm(dataloader, desc="Evaluating"):
            batch_size = inputs.shape[0]  # dim 0 is the batch dim, before unsqueezing
            inputs = inputs.to(device).unsqueeze(1).unsqueeze(3)
            labels_scaled = labels_scaled.to(device)
            predictions = model(inputs)
            loss = loss_fn(predictions, labels_scaled)  # mean over this batch
            weighted_loss_sum += loss.item() * batch_size
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError("evaluate_model received an empty dataloader")
    return weighted_loss_sum / num_samples


if __name__ == '__main__':
    # Script flow: load data -> load packaged model -> FP32 baseline ->
    # torchao QAT "prepare" + fine-tune -> "convert" -> evaluate -> torch.export.

    # --- 1. Load Data ---
    print("--- 1. Loading Data ---")
    dm = PumpDataModule(data_path=DATA_FILE, batch_size=BATCH_SIZE, train_ratio=0.9)
    dm.setup()
    train_dataloader = dm.train_dataloader()
    val_dataloader = dm.val_dataloader()
    print("✅ Data loaded.")

    # --- 2. Dynamically Load the "White Box" Model ---
    print("\n--- 2. Loading 'White Box' Model from Artifact Package ---")
    CHECKPOINT_PATH = os.path.join(ARTIFACT_DIR, "best_model.ckpt")
    SCALERS_PATH = os.path.join(ARTIFACT_DIR, "scalers.joblib")

    PumpParameterRegressor = dynamically_load_model_class(ARTIFACT_DIR)
    # NOTE(review): joblib.load unpickles arbitrary objects; only use trusted artifact dirs.
    scalers = joblib.load(SCALERS_PATH)

    # The y-scaler and loss constants are passed explicitly rather than taken
    # from the checkpoint (presumably they were not saved as hparams -- confirm).
    model_pl = PumpParameterRegressor.load_from_checkpoint(
        CHECKPOINT_PATH,
        y_scaler=scalers['y_scaler'],
        constants=SIM_CONSTANTS_FOR_LOSS
    )
    model_fp32 = model_pl.model.to(DEVICE).eval() # Get the raw nn.Module
    print("✅ 'White Box' model loaded successfully.")

    # --- 3. Establish FP32 Baseline ---
    print("\n--- 3. Establishing FP32 Performance Baseline ---")
    baseline_loss = evaluate_model(model_fp32, val_dataloader, device=DEVICE)
    print(f"🎯 FP32 Baseline MSE Loss: {baseline_loss:.6f}")

    # --- 4. Prepare and Run torchao QAT ---
    print("\n--- 4. Wrapping model for torchao QAT ---")
    # quantize_ modifies the module IN PLACE: from here on `model_fp32` holds the
    # QAT-prepared (and, after step 6, converted) model, not the FP32 one.
    # The scheme is int8 dynamic activations + int4 weights with group_size=32,
    # even though the messages and file name below say "INT8".
    # NOTE(review): quantize_'s default filter targets nn.Linear layers only, and
    # group_size presumably must divide each Linear's in_features -- confirm.
    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
    qat_config = QATConfig(base_config, step="prepare")
    quantize_(model_fp32.train(), qat_config)

    # The optimizer is built after "prepare" so it tracks the prepared model's parameters.
    optimizer = torch.optim.SGD(model_fp32.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.MSELoss()

    print("\n--- 5. Starting QAT Fine-Tuning ---")
    for epoch in range(FINE_TUNE_EPOCHS):
        model_fp32.train()
        progress_bar = tqdm(train_dataloader, desc=f"QAT Epoch {epoch+1}/{FINE_TUNE_EPOCHS}")
        for inputs, labels, _ in progress_bar:
            optimizer.zero_grad()
            # Same input reshaping and loss target as evaluate_model.
            predictions = model_fp32(inputs.to(DEVICE).unsqueeze(1).unsqueeze(3))
            loss = loss_fn(predictions, labels.to(DEVICE))
            loss.backward()
            optimizer.step()
            progress_bar.set_postfix({'loss': loss.item()})
    print("✅ QAT Fine-tuning complete.")

    # --- 6. Convert, Evaluate, and Export INT8 Model ---
    print("\n--- 6. Converting, Evaluating, and Exporting torchao INT8 Model ---")

    # "convert" must use the same base_config as "prepare".
    convert_qat_config = QATConfig(base_config, step="convert")
    quantize_(model_fp32.eval(), convert_qat_config)
    model_int8 = model_fp32  # alias only; same (now converted) module object
    quantized_loss = evaluate_model(model_int8, val_dataloader, device=DEVICE)

    print("\n--- 📊 torchao QAT RESULTS ---")
    print(f"FP32 Baseline Loss:   {baseline_loss:.6f}")
    print(f"INT8 QAT Loss:        {quantized_loss:.6f}")
    print(f"Performance Change:   {quantized_loss - baseline_loss:+.6f}")

    # Export the new INT8 'black box' model. Export is best-effort: a failure is
    # reported (with its full traceback) but does not abort the script.
    try:
        import torch.export
        print("\nExporting 'black box' (model_int8_torchao.ep)...")
        # NOTE(review): dm.input_size is assumed to be the per-sample sequence length.
        # The trace uses batch size 1 with no dynamic_shapes, so the exported
        # program is likely specialised to batch 1 -- confirm before batching.
        dummy_input = torch.randn(1, 1, dm.input_size, 1, device=DEVICE)
        exported_int8_program = torch.export.export(model_int8.to('cpu'), (dummy_input.to('cpu'),))

        exported_int8_path = os.path.join(ARTIFACT_DIR, "model_int8_torchao.ep")
        torch.export.save(exported_int8_program, exported_int8_path)
        print(f"✅ Exported torchao INT8 'black box' to: {exported_int8_path}")

    except Exception as e:
        print(f"!!!!!!!!!!!!!\nFailed to export INT8 model: {e}\n!!!!!!!!!!!!!")
        # The one-line message alone hides where export failed; keep the stack trace.
        traceback.print_exc()