# data.py
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler,MinMaxScaler

class PumpCurveDataset(Dataset):
    """
    Custom PyTorch Dataset that serves randomly sliced and padded pump curves.

    Every item is a prefix of one stored curve, cut at a random length and
    right-padded with zeros back to ``max_len``.

    Args:
        X_data: Indexable collection of curves. ``X_data[idx]`` must support
            slicing and is expected to hold at least ``max_len`` values.
        y_data: Indexable collection of labels aligned with ``X_data``.
        max_len: Length of every returned (padded) curve and upper bound of
            the random crop length.
        min_slice_perc: Lower bound of the crop length, as a fraction of
            ``max_len``. The resulting length is clamped to ``[1, max_len]``.
        is_train: Noise augmentation is only ever applied when this is True.
        scaled_noise_std: Per-position factor that multiplies the noise; it
            is sliced to the crop length. Required when noise is enabled.
        noise_level: Noise amplitude before ``scaled_noise_std`` is applied
            (the original code comment gives its unit as ml/min). The default
            of 0 disables augmentation.

    Raises:
        ValueError: If noise is requested for training data but
            ``scaled_noise_std`` is missing.
    """
    def __init__(self, X_data, y_data, max_len, min_slice_perc=0.1, is_train=True,
                 scaled_noise_std=None, noise_level=0.0):
        if is_train and noise_level != 0 and scaled_noise_std is None:
            raise ValueError(
                "scaled_noise_std is required when noise_level is non-zero for training data."
            )
        self.X_data = X_data
        self.y_data = y_data
        self.max_len = max_len
        # int() truncates, so a short max_len could give a minimum of 0 and
        # therefore samples that contain nothing but padding. Keep at least
        # one real point, without ever exceeding max_len.
        self.min_len = min(max(1, int(max_len * min_slice_perc)), max_len)
        self.is_train = is_train
        self.scaled_noise_std = scaled_noise_std
        self.noise_level = noise_level

    def __len__(self):
        """Return the number of stored curves."""
        return len(self.X_data)

    def __getitem__(self, idx):
        """
        Return ``(padded_curve, label, crop_len)`` for the curve at ``idx``.

        ``padded_curve`` is a float32 tensor of length ``max_len``, ``label``
        a float32 tensor, and ``crop_len`` an int32 scalar tensor holding the
        number of leading values that are real data.
        """
        full_curve = self.X_data[idx]
        label = self.y_data[idx]

        # 1. Randomly choose a crop length in [min_len, max_len].
        # NOTE(review): this uses the global NumPy RNG; confirm that
        # DataLoader workers end up with distinct NumPy seeds.
        crop_len = np.random.randint(self.min_len, self.max_len + 1)

        # 2. Slice the curve from the beginning
        cropped_curve = full_curve[:crop_len]

        # 3. Add noise -- ONLY to training data, and only when enabled, so
        #    the default configuration never touches scaled_noise_std.
        if self.is_train and self.noise_level != 0:
            noise = (
                np.random.randn(*cropped_curve.shape)
                * self.noise_level
                * self.scaled_noise_std[:crop_len]  # Scales the noise
            )
            cropped_curve = cropped_curve + noise

        # 4. Create a padded array and fill it (right padding)
        padded_curve = np.zeros(self.max_len, dtype=np.float32)
        padded_curve[:crop_len] = cropped_curve

        return (
            torch.tensor(padded_curve, dtype=torch.float32),
            torch.tensor(label, dtype=torch.float32),
            torch.tensor(crop_len, dtype=torch.int32),
        )


class PumpDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule to handle data loading, splitting, and scaling.

    Inputs are standardized column by column. Labels are first mapped to
    log/symlog space and then standardized; ``unscale_predictions`` reverses
    that label pipeline.

    Args:
        data_path: Path handed to ``np.load``; the file must expose the
            arrays ``'X'`` and ``'y'``.
        batch_size: Batch size used by both dataloaders.
        train_ratio: Fraction of the samples assigned to the training split.
        num_workers: Number of DataLoader worker processes (0 loads in the
            main process).
    """

    # Label layout is assumed to be ['alpha2', 'gamma2', 'delta'].
    # If your order is different, adjust the column indices below.
    _POS_COLS = [0, 1]   # Columns for alpha2, gamma2 -> plain log
    _SIGNED_COL = 2      # Column for delta -> symmetric log
    _LOG_EPS = 1e-9      # Small epsilon to prevent log(0)

    def __init__(self, data_path, batch_size=64, train_ratio=0.9, num_workers=4):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.train_ratio = train_ratio
        self.num_workers = num_workers

        # Initialize scalers for both inputs (X) and labels (y)
        self.x_scaler = StandardScaler()
        self.y_scaler = StandardScaler()

    @classmethod
    def _transform_labels(cls, y):
        """Return a copy of ``y`` with its columns mapped to log/symlog space."""
        transformed = y.copy()
        # Standard log for the positive-only columns
        transformed[:, cls._POS_COLS] = np.log(y[:, cls._POS_COLS] + cls._LOG_EPS)
        # Symmetric log for the signed column
        signed = y[:, cls._SIGNED_COL]
        transformed[:, cls._SIGNED_COL] = np.sign(signed) * np.log1p(np.abs(signed))
        return transformed

    def setup(self, stage=None):
        """
        Load the data, split it, fit the scalers and build both datasets.

        Sets ``input_size``, ``output_size``, ``train_dataset`` and
        ``val_dataset``. The scalers are fitted on the training split only.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        try:
            # An .npz archive stays open until it is closed; the context
            # manager releases the handle once both arrays have been read.
            with np.load(self.data_path) as data:
                X, y = data['X'], data['y']
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Data file not found at '{self.data_path}'. Please generate it first."
            ) from err

        self.input_size = X.shape[1]
        self.output_size = y.shape[1]

        # Split data (fixed seed, so the split is reproducible)
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=(1.0 - self.train_ratio), random_state=42
        )

        # Transform labels to log/symlog space before scaling
        y_train_transformed = self._transform_labels(y_train)
        y_val_transformed = self._transform_labels(y_val)

        # Fit on the training split only, then apply to both splits
        X_train_scaled = self.x_scaler.fit_transform(X_train)
        X_val_scaled = self.x_scaler.transform(X_val)
        y_train_scaled = self.y_scaler.fit_transform(y_train_transformed)
        y_val_scaled = self.y_scaler.transform(y_val_transformed)

        # Create custom datasets. The reciprocal of the fitted input scale is
        # handed to the training dataset as its per-position noise factor.
        noise_scale = 1.0 / self.x_scaler.scale_
        self.train_dataset = PumpCurveDataset(
            X_train_scaled, y_train_scaled, max_len=self.input_size,
            is_train=True, scaled_noise_std=noise_scale,
        )
        self.val_dataset = PumpCurveDataset(
            X_val_scaled, y_val_scaled, max_len=self.input_size, is_train=False,
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers is 0
            persistent_workers=self.num_workers > 0,
            pin_memory=True
        )

    def val_dataloader(self):
        """Return the (unshuffled) DataLoader over the validation dataset."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers is 0
            persistent_workers=self.num_workers > 0,
            pin_memory=True
        )

    def unscale_predictions(self, preds_tensor):
        """
        Helper to convert scaled model output back to real physical values.
        This includes the inverse log/symlog transforms.

        Args:
            preds_tensor: Torch tensor or array-like of scaled predictions,
                either ``(n_samples, n_labels)`` or a single 1-D sample.

        Returns:
            NumPy array of the same shape holding the values after the inverse
            scaling and the inverse log/symlog transforms.
        """
        if isinstance(preds_tensor, torch.Tensor):
            # detach() first: numpy() refuses tensors that require grad
            preds_tensor = preds_tensor.detach().cpu().numpy()

        preds = np.asarray(preds_tensor)
        # The scaler only accepts 2-D input; treat a 1-D vector as one sample
        single_sample = preds.ndim == 1
        if single_sample:
            preds = preds.reshape(1, -1)

        # Step 1: Inverse scale to get back to the log/symlog space
        unscaled_transformed_preds = self.y_scaler.inverse_transform(preds)

        # Create a copy to store the final real values
        preds_real = unscaled_transformed_preds.copy()

        # Step 2: Apply the inverse of the log/symlog transforms
        # Inverse of standard log is exponent
        preds_real[:, self._POS_COLS] = np.exp(unscaled_transformed_preds[:, self._POS_COLS])

        # Inverse of symmetric log
        symlog_vals = unscaled_transformed_preds[:, self._SIGNED_COL]
        preds_real[:, self._SIGNED_COL] = np.sign(symlog_vals) * np.expm1(np.abs(symlog_vals))

        return preds_real[0] if single_sample else preds_real