from .adc12qj1600_header import *

import contextlib
import time

class ADC12QJ1600():
    """Driver for the TI ADC12QJ1600, programmed through an SPI register interface.

    ``spi_interface`` must provide ``read(addr)`` and ``write(addr, data)``.
    Register fields are passed around as ``reg_info`` tuples of
    ``(address, bit_offset, field_mask)``; the concrete tuples (``*_ENUM``
    names) and field values come from ``adc12qj1600_header``.
    """

    # Registers are treated as 8 bits wide by the read-modify-write helpers.
    reg_size_mask = 0xFF

    def __init__(self, device_name, sampling_frequency, spi_interface, sys_ref_term='diff', dev_clk_term='diff', cpll_ref_freq=50e6) -> None:
        """Store configuration, soft-reset the device and wait for INIT_DONE.

        Args:
            device_name: Free-form name, stored as ``self.dev_name``.
            sampling_frequency: Target sample rate in Hz; used by
                :meth:`configure_cpll` to derive the CPLL dividers.
            spi_interface: Object providing ``read(addr)`` and ``write(addr, data)``.
            sys_ref_term: Stored on the instance.
            dev_clk_term: Stored on the instance.
            cpll_ref_freq: CPLL reference frequency in Hz (default 50 MHz).

        Raises:
            SystemExit: if INIT_DONE does not read back as 1 within 10000 polls.
        """
        print('***************************************************************************')
        print('\tInitializing ADC12QJ1600 instance with defined attributes')

        self.dev_name = device_name
        self.sampling_frequency = sampling_frequency
        self.spi_interface = spi_interface

        self.sys_ref_term = sys_ref_term
        self.dev_clk_term = dev_clk_term
        self.cpll_ref_freq = cpll_ref_freq

        self.reg_size_mask = 0xFF

        # Candidate CPLL P and V divider values, searched in this order by
        # calculate_dividers (set_p_divider / set_v_divider accept exactly these).
        self._P = [1, 2, 4]
        self._V = [5, 4, 3]

        # Allowed CPLL reference frequency range in Hz (checked in calculate_dividers).
        self.f_pll_pfd_min = 50e6
        self.f_pll_pfd_max = 500e6

        # Allowed VCO frequency range in Hz.
        self.f_vco_min = 7.2e9
        self.f_vco_max = 8.2e9

        # Driver-side mirrors of JESD_EN / CAL_EN; refreshed from the device
        # once initialization completes below.
        self.jesd_en = 0
        self.cal_en = 0
        self.sysref_rx_en = 0

        # CPLL P divider chosen by configure_cpll (empty until then).
        self.p = []

        self.device_reset()

        start_timeout = 10000
        init_bit = 0
        for _ in range(start_timeout):
            init_bit = self.reg_field_read(INIT_STATUS_INIT_DONE_ENUM)
            if init_bit == 1:
                print("\tGot Init Bit")
                break

        if init_bit != 1:
            # Fail loudly: every later register write would be meaningless on an
            # uninitialized device.
            raise SystemExit("Error! NEver got initialized")

        # Mirror the device's real JESD/CAL enable state instead of assuming 0,
        # so jesd_disable()/cal_disable() actually clear the bits if the device
        # comes up with either one set.
        self.jesd_en = self.reg_field_read(JESD_EN_ENUM)
        self.cal_en = self.reg_field_read(CAL_EN_ENUM)

        print("\tADC device initialization complete successfuly\n\tADC is ready for programming")
        print('***************************************************************************\n')


    #############################################################
    # MACROS
    #############################################################

    def configure_cpll(self, ref_clock_source):
        """Program and start the converter PLL for ``self.sampling_frequency``.

        Args:
            ref_clock_source: ``"DIFF"`` or ``"SE"`` (case-insensitive), selecting
                the differential or single-ended CPLL reference input.

        Raises:
            SystemExit: for an unknown ``ref_clock_source``, or when
                :meth:`calculate_dividers` / the divider setters reject the
                sampling rate.
        """
        source = str(ref_clock_source).upper()
        if source not in ("DIFF", "SE"):
            raise SystemExit("Invalid ref_clock_source, expected 'DIFF' or 'SE'")

        # Compute (and validate) the dividers before touching the device so an
        # unreachable sampling rate does not leave the CPLL half-configured.
        self.p, self.v, self.n, self.f_vco = self.calculate_dividers(self.sampling_frequency, self.cpll_ref_freq)
        print(f"P = {self.p}, V = {self.v}, N = {self.n}, FVCO = {self.f_vco}")

        self.configure_cpll_vco(True)
        self.cpll_spi_override_enable()

        if source == "DIFF":
            self.cpll_DIFF_ref_clk()
        else:
            self.cpll_SE_ref_clk()

        self.cpll_enable()
        # The CPLL is held in reset while the dividers and VCO calibration are set up.
        self.cpll_reset_enable()

        self.set_p_divider(self.p)
        self.set_v_divider(self.v)
        self.set_n_divider(self.n)

        self.set_vco_cal_settling_time(0x04)
        self.vco_cal_enable()
        self.cpll_reset_disable()

    def configure_trigout_fpga_ref(self, refclk_ratio=32):
        """Set the TRIGOUT divider (16, 32, 64, or -1; see set_trigout_divider) and enable TRIGOUT."""
        self.set_trigout_divider(refclk_ratio)
        self.trigout_fpga_output_enable()

    def configure_clock_noise_supression(self):
        """Write the CLK_CTRL2 reserved field and enable both noise-suppression bits."""
        self.clock_ctrl_2_reserved_write()
        self.va11q_noise_suppr_enable()
        self.vclk11_noise_suppr_enable()

    def configure_jesd(self, jmode, k):
        """Program JMODE, K, scrambling, sample format and over-range settings.

        JESD and calibration are disabled for the duration; both are enabled at
        the end regardless of their state on entry.
        """
        self.jesd_disable()
        self.cal_disable()

        self.set_jmode(jmode)
        self.set_k(k)
        self.scrambling_enable()
        # NOTE(review): raw 0 written to SFORMAT; set_format('offset_bin') uses
        # JCTRL_SFORMAT_OFFSET_BIN instead -- confirm both mean the same value.
        self.reg_field_write(JCTRL_SFORMAT_ENUM, 0)
        self.configure_ovr_cfg()

        self.cal_enable()
        self.jesd_enable()

    def toggle_software_cal(self):
        """Pulse the calibration soft trigger: write CAL_SOFT_TRIG 0, then 1."""
        self.soft_trig_disable()
        self.soft_trig_enable()


    #############################################################
    # Low Level Functions
    #############################################################

    def device_reset(self):
        """Issue a soft reset through CONFIG_A and wait before further writes.

        At least 750 ns must elapse after the soft reset before writing any
        other register; this waits 800 us, which covers that with ample margin.
        """
        self.reg_field_write(CONFIG_A_SOFT_RESET_ENUM, CONFIG_A_SOFT_RESET_ENABLE)
        time.sleep(800 / 1e6)

    def configure_cpll_vco(self, cpll_on):
        """Select the VCO bias setting for CPLL enabled (truthy) or disabled."""
        if cpll_on:
            self.reg_field_write(CPLL_VCOCTRL1_VCO_BIAS_ENUM, CPLL_VCOCTRL1_VCO_BIAS_CPLL_EN)
        else:
            self.reg_field_write(CPLL_VCOCTRL1_VCO_BIAS_ENUM, CPLL_VCOCTRL1_VCO_BIAS_CPLL_DIS)

    def cpll_spi_override_enable(self):
        """Set CPLL_OVR_EN = 1."""
        self.reg_field_write(CPLL_OVR_CPLL_OVR_EN_ENUM, 0x01)

    def cpll_spi_override_disable(self):
        """Set CPLL_OVR_EN = 0."""
        self.reg_field_write(CPLL_OVR_CPLL_OVR_EN_ENUM, 0x00)

    def cpll_SE_ref_clk(self):
        """Select the single-ended CPLL reference (CPLLREF_SE override value = 1)."""
        self.reg_field_write(CPLL_OVR_CPLLREF_SE_OVR_VALUE_ENUM, 0x01)

    def cpll_DIFF_ref_clk(self):
        """Select the differential CPLL reference (CPLLREF_SE override value = 0)."""
        self.reg_field_write(CPLL_OVR_CPLLREF_SE_OVR_VALUE_ENUM, 0x00)

    def cpll_enable(self):
        """Set the CPLL_EN override value to 1."""
        self.reg_field_write(CPLL_OVR_CPLL_EN_OVR_VALUE_ENUM, 0x01)

    def cpll_disable(self):
        """Set the CPLL_EN override value to 0."""
        self.reg_field_write(CPLL_OVR_CPLL_EN_OVR_VALUE_ENUM, 0x00)

    def cpll_reset_enable(self):
        """Assert CPLL reset (CPLL_RESET = 1)."""
        self.reg_field_write(CPLL_RESET_ENUM, 0x01)

    def cpll_reset_disable(self):
        """Release CPLL reset (CPLL_RESET = 0)."""
        self.reg_field_write(CPLL_RESET_ENUM, 0x00)

    def set_p_divider(self, p):
        """Program the CPLL P divider; ``p`` must be 1, 2 or 4.

        Raises:
            SystemExit: for any other value.
        """
        if p == 1:
            data = CPLL_FB_DIV1_PLL_P_DIV_DIV_1
        elif p == 2:
            data = CPLL_FB_DIV1_PLL_P_DIV_DIV_2
        elif p == 4:
            data = CPLL_FB_DIV1_PLL_P_DIV_DIV_4
        else:
            raise SystemExit("Invalid P Value")

        self.reg_field_write(CPLL_FB_DIV1_PLL_P_DIV_ENUM, data)

    def set_v_divider(self, v):
        """Program the CPLL V divider; ``v`` must be 5, 4 or 3.

        Raises:
            SystemExit: for any other value.
        """
        if v == 5:
            data = CPLL_FB_DIV1_PLL_V_DIV_DIV_5
        elif v == 4:
            data = CPLL_FB_DIV1_PLL_V_DIV_DIV_4
        elif v == 3:
            data = CPLL_FB_DIV1_PLL_V_DIV_DIV_3
        else:
            raise SystemExit("Invalid V Value")

        self.reg_field_write(CPLL_FB_DIV1_PLL_V_DIV_ENUM, data)

    def set_n_divider(self, n):
        """Program the CPLL N divider; ``n`` must be in 1..63.

        Raises:
            SystemExit: when ``n`` is out of range.
        """
        if n > 63 or n < 1:
            raise SystemExit("Invalid N Value")

        self.reg_field_write(CPLL_FBDIV2_PLL_N_DIV_ENUM, n)

    def set_vco_cal_settling_time(self, val):
        """Write ``val`` to the VCO_CAL_STL field."""
        self.reg_field_write(VCO_CAL_CTRL_VCO_CAL_STL_ENUM, val)

    def vco_cal_enable(self):
        """Set VCO_CAL_EN = 1."""
        self.reg_field_write(VCO_CAL_CTRL_VCO_CAL_EN_ENUM, 0x01)

    def vco_cal_disable(self):
        """Set VCO_CAL_EN = 0."""
        self.reg_field_write(VCO_CAL_CTRL_VCO_CAL_EN_ENUM, 0x00)

    def trigout_fpga_output_enable(self):
        """Set TRIGOUT_EN = 1."""
        self.reg_field_write(TRIGOUT_CTRL_TRIGOUT_EN_ENUM, 0x01)

    def trigout_fpga_output_disable(self):
        """Set TRIGOUT_EN = 0."""
        self.reg_field_write(TRIGOUT_CTRL_TRIGOUT_EN_ENUM, 0x00)

    def set_trigout_divider(self, rx_div):
        """Select the TRIGOUT divider: 16, 32 or 64, or -1 for the TMSTP setting.

        Raises:
            SystemExit: for any other value.
        """
        if rx_div == 16:
            data = TRIGOUT_CTRL_TRIGOUT_DIV_16
        elif rx_div == 32:
            data = TRIGOUT_CTRL_TRIGOUT_DIV_32
        elif rx_div == 64:
            data = TRIGOUT_CTRL_TRIGOUT_DIV_64
        elif rx_div == -1:
            data = TRIGOUT_CTRL_TRIGOUT_DIV_TMSTP
        else:
            raise SystemExit("Invalid rx_div_provided")

        self.reg_field_write(TRIGOUT_CTRL_TRIGOUT_DIV_ENUM, data)

    def calculate_dividers(self, fs, f_ref):
        """Choose CPLL dividers for sampling rate ``fs`` from reference ``f_ref`` (Hz).

        ``N = fs / f_ref`` must be an integer. ``P`` and ``V`` are searched in
        ``self._P`` x ``self._V`` order, and the first pair whose
        ``fs * P * V`` lies within ``[self.f_vco_min, self.f_vco_max]`` is used.

        Returns:
            tuple: ``(p, v, n, f_vco)`` as ``(int, int, int, float)``.

        Raises:
            SystemExit: if ``f_ref`` is outside ``[f_pll_pfd_min, f_pll_pfd_max]``,
                ``fs / f_ref`` is not an integer, or no (P, V) pair gives a VCO
                frequency in range.
        """
        if f_ref < self.f_pll_pfd_min or f_ref > self.f_pll_pfd_max:
            raise SystemExit("Invalid reference frequency, the given cpll reference frequency is not in the range of the PFD")

        n = fs / f_ref
        if n % 1 != 0:
            raise SystemExit("Invalid sampling frequency, N (fs/f_ref) must be a integer")

        # Stop at the first valid (P, V) pair rather than letting a later P overwrite it.
        match = next(
            ((_p, _v) for _p in self._P for _v in self._V
             if self.f_vco_min <= fs * (_p * _v) <= self.f_vco_max),
            None,
        )
        if match is None:
            raise SystemExit("Invalid sampling frequency, could not find a valid VCO frequency")

        p, v = int(match[0]), int(match[1])
        f_vco = int(fs * (p * v))

        if (f_vco / (p * v * n)) != f_ref:
            raise SystemExit("Invalid sampling frequency, Invalid Divider settings")

        return int(p), int(v), int(n), float(f_vco)

    def clock_ctrl_2_reserved_write(self):
        """Write 0x02 to the CLK_CTRL2 reserved field."""
        self.reg_field_write(CLK_CTRL2_RESERVED_ENUM, 0x02)

    # NOTE(review): the two noise-suppression constants below lack the `_ENUM`
    # suffix used by every other reg_info argument; confirm they are
    # (addr, offset, mask) tuples in adc12qj1600_header.
    def va11q_noise_suppr_enable(self):
        """Set the VA11Q noise-suppression bit to 1."""
        self.reg_field_write(CLK_CTRL2_VA11Q_NOISESUPPR_EN, 0x01)

    def va11q_noise_suppr_disable(self):
        """Set the VA11Q noise-suppression bit to 0."""
        self.reg_field_write(CLK_CTRL2_VA11Q_NOISESUPPR_EN, 0x00)

    def vclk11_noise_suppr_enable(self):
        """Set the VCLK11 noise-suppression bit to 1."""
        self.reg_field_write(CLK_CTRL2_VCLK11_NOISESUPPR_EN, 0x01)

    def vclk11_noise_suppr_disable(self):
        """Set the VCLK11 noise-suppression bit to 0."""
        self.reg_field_write(CLK_CTRL2_VCLK11_NOISESUPPR_EN, 0x00)

    def jesd_disable(self):
        """Clear JESD_EN; no write when the cached ``jesd_en`` is already 0."""
        if self.jesd_en:
            self.reg_field_write(JESD_EN_ENUM, JESD_EN_DISABLE_JESD)
            self.jesd_en = 0

    def jesd_enable(self):
        """Set JESD_EN; no write when the cached ``jesd_en`` is already set."""
        if not self.jesd_en:
            self.reg_field_write(JESD_EN_ENUM, JESD_EN_ENABLE_JESD)
            self.jesd_en = 1

    def cal_disable(self):
        """Clear CAL_EN; no write when the cached ``cal_en`` is already 0."""
        if self.cal_en:
            self.reg_field_write(CAL_EN_ENUM, 0)
            self.cal_en = 0

    def cal_enable(self):
        """Set CAL_EN; no write when the cached ``cal_en`` is already set."""
        if not self.cal_en:
            self.reg_field_write(CAL_EN_ENUM, 1)
            self.cal_en = 1

    @contextlib.contextmanager
    def _jesd_cal_paused(self):
        """Temporarily clear JESD_EN and CAL_EN, then restore whichever were set.

        JESD is disabled before CAL and re-enabled before CAL. Restoration runs
        even if the body raises, so the link is not left disabled by a failed write.
        """
        re_en_jesd = bool(self.jesd_en)
        re_en_cal = bool(self.cal_en)
        if re_en_jesd:
            self.jesd_disable()
        if re_en_cal:
            self.cal_disable()
        try:
            yield
        finally:
            if re_en_jesd:
                self.jesd_enable()
            if re_en_cal:
                self.cal_enable()

    def set_jmode(self, jmode):
        """Set the ADC's JMODE (valid range 0..14; the register defaults to zero).

        JESD and calibration must be disabled while JMODE is changed. They are
        paused around the write and restored afterwards.

        Raises:
            SystemExit: if ``jmode`` is outside 0..14. Nothing is written in
                that case.
        """
        print("Setting the ADC's JMODE to " + str(jmode))

        # Validate before touching the device so a bad value leaves JESD/CAL untouched.
        if not 0 <= jmode <= 14:
            raise SystemExit("Invalid JMODE provided")

        with self._jesd_cal_paused():
            self.reg_field_write(JMODE_ENUM, jmode)

    def set_k(self, k):
        """Set the JESD K parameter by writing ``k - 1`` to KM1.

        Raises:
            SystemExit: if ``k`` is outside 1..256. Nothing is written in that case.
        """
        # TODO: the valid K range also depends on JMODE; only the register range is checked.
        if not 1 <= k <= 256:
            raise SystemExit("Invalid K parameter provided")

        with self._jesd_cal_paused():
            self.reg_field_write(KM1_ENUM, k - 1)

    def set_format(self, mode='offset_bin'):
        """Select the output sample format: ``'offset_bin'`` or ``'2_comp'``.

        Raises:
            SystemExit: for any other ``mode``. Nothing is written in that case.
        """
        if mode == 'offset_bin':
            value = JCTRL_SFORMAT_OFFSET_BIN
        elif mode == '2_comp':
            value = JCTRL_SFORMAT_SIGN_2_COMPLEMENT
        else:
            raise SystemExit("Invalid Output format mode provided")

        with self._jesd_cal_paused():
            self.reg_field_write(JCTRL_SFORMAT_ENUM, value)

    def scrambling_enable(self):
        """Enable 8b/10b scrambling (JESD/CAL paused around the write)."""
        with self._jesd_cal_paused():
            self.reg_field_write(JCTRL_SCR_ENUM, JCTRL_SCR_8B10B_SCRAMBLE_ENABLED)

    def scrambling_disable(self):
        """Disable 8b/10b scrambling (JESD/CAL paused around the write)."""
        with self._jesd_cal_paused():
            self.reg_field_write(JCTRL_SCR_ENUM, JCTRL_SCR_8B10B_SCRAMBLE_DISABLED)

    def configure_ovr_cfg(self):
        """Set OVR_EN = 1 and OVR_N = 7 (JESD/CAL paused around the writes)."""
        with self._jesd_cal_paused():
            self.reg_field_write(OVR_CFG_OVR_EN_ENUM, 0x01)
            self.reg_field_write(OVR_CFG_OVR_N_ENUM, 0x07)

    def soft_trig_enable(self):
        """Write CAL_SOFT_TRIG = 1."""
        self.reg_field_write(CAL_SOFT_TRIG_ENUM, 1)

    def soft_trig_disable(self):
        """Write CAL_SOFT_TRIG = 0."""
        self.reg_field_write(CAL_SOFT_TRIG_ENUM, 0)

    #############################################################
    # Insert Constant Write and Read functions here
    #############################################################

    #############################################################
    # hw_write performs the low level writing of the data to the register
    #############################################################
    def hw_write(self, addr, data):
        """Write ``data`` to register ``addr`` via the SPI interface."""
        return self.spi_interface.write(addr, data)

    #############################################################
    # hw_read performs the low level reading of the data from the register
    #############################################################
    def hw_read(self, addr):
        """Read register ``addr`` via the SPI interface."""
        return self.spi_interface.read(addr)

    #############################################################
    # Function to set a specific field in a register. Does not
    # carry out the actual write to the register
    #############################################################
    @staticmethod
    def set_reg_field(reg_value, field_offset, field_mask, field_value):
        """Return ``reg_value`` (clamped to 8 bits) with one field replaced.

        ``field_value`` is masked to ``field_mask`` so an out-of-range or
        negative value cannot spill into neighbouring bits.
        """
        reg_val = reg_value & ADC12QJ1600.reg_size_mask
        field_bits = field_mask << field_offset
        return (reg_val & ~field_bits) | ((field_value & field_mask) << field_offset)

    #############################################################
    # Function to extract a specific field of a register (returns right aligned)
    # Does not carry out the actual read from the register
    #reg_value, field_offset, field_mask
    #############################################################
    @staticmethod
    def get_reg_field(reg_value, field_offset, field_mask):
        """Return the right-aligned field ``(reg_value >> field_offset) & field_mask``."""
        return (reg_value >> field_offset) & field_mask

    #############################################################
    # Function to read ADC register
    #############################################################

    def reg_read(self, addr):
        """Read one register."""
        return self.hw_read(addr)

    #############################################################
    # Function to write ADC register
    #############################################################
    def reg_write(self, addr, data):
        """Write one register."""
        return self.hw_write(addr, data)

    #############################################################
    # Function to read and return a specific field of a register
    # Combines the reg_read and get_reg_field functions into one
    # Input arguments: register address, field offset, field mask
    # Returns: Field value, right justified
    #reg_addr, reg_field_offset, reg_field_mask
    #############################################################
    def reg_field_read(self, reg_info):
        """Read the field described by ``reg_info = (addr, offset, mask)``, right-aligned."""
        reg_addr, reg_field_offset, reg_field_mask = reg_info[0], reg_info[1], reg_info[2]

        reg_val = self.reg_read(reg_addr)
        return self.get_reg_field(reg_val, reg_field_offset, reg_field_mask)

    #############################################################
    # Function to write a specific field of a register
    # Creates a read-modify-write sequence
    # Input arguments: register address, field offset, field mask, field value
    # Returns: NA
    #############################################################
    def reg_field_write(self, reg_info, field_value):
        """Read-modify-write the field described by ``reg_info = (addr, offset, mask)``."""
        reg_addr, reg_field_offset, reg_field_mask = reg_info[0], reg_info[1], reg_info[2]
        reg_val = self.reg_read(reg_addr)

        reg_val = self.set_reg_field(reg_val, reg_field_offset, reg_field_mask, field_value)
        self.hw_write(reg_addr, reg_val)

    def reg_field_clear(self, reg_info, field_value, bit_flag, bit_info):
        """Write one field and, in the same register write, a second flag field.

        Used where a register holds a bit that must be written explicitly
        (e.g. a clear-on-write flag) whenever other fields of it are updated.
        Only ``bit_info[1]`` (offset) and ``bit_info[2]`` (mask) are used; the
        flag is assumed to live in the same register as ``reg_info``.
        """
        reg_addr, reg_field_offset, reg_field_mask = reg_info[0], reg_info[1], reg_info[2]

        reg_val = self.reg_read(reg_addr)
        reg_val = self.set_reg_field(reg_val, reg_field_offset, reg_field_mask, field_value)
        reg_val = self.set_reg_field(reg_val, bit_info[1], bit_info[2], bit_flag)

        self.hw_write(reg_addr, reg_val)


    def reg_burst_read(self, reg_mask, num_reg_req, reg_info):
        """Read ``num_reg_req`` consecutive registers starting at ``reg_info[0]``.

        Bytes are assembled little-endian (the lowest address is the least
        significant byte), and the result is masked with ``reg_mask``.
        """
        read_data = 0
        for curr_reg in range(num_reg_req):
            data = self.reg_read(reg_info[0] + curr_reg)
            # Accumulate every byte; assigning here would keep only the last one.
            read_data |= data << (8 * curr_reg)

        return read_data & reg_mask

    def reg_burst_write(self, reg_info, data, num_reg_req):
        """Write ``data`` little-endian across ``num_reg_req`` registers from ``reg_info[0]``.

        Each byte is read-modify-written into the field given by
        ``reg_info[1]`` (offset) and ``reg_info[2]`` (mask) of its register.
        """
        reg_size = 8
        for curr_reg in range(num_reg_req):
            addr = reg_info[0] + curr_reg
            write_data = (data >> (reg_size * curr_reg)) & ADC12QJ1600.reg_size_mask
            reg_val = self.reg_read(addr)
            reg_val = self.set_reg_field(reg_val, reg_info[1], reg_info[2], write_data)
            self.hw_write(addr, reg_val)