diff options
| author | Clément Zrounba <6691770+clement-z@users.noreply.github.com> | 2023-09-30 23:06:01 +0200 | 
|---|---|---|
| committer | Clément Zrounba <6691770+clement-z@users.noreply.github.com> | 2023-09-30 23:26:46 +0200 | 
| commit | ff9b8bb838ecdfbfc1dc81038fcf3b2a87636982 (patch) | |
| tree | 21f27be782ce11c6d00b96ce100a2bff88141b2e /utils/analyze_detector_trace.py | |
| download | specs-ff9b8bb838ecdfbfc1dc81038fcf3b2a87636982.tar.gz specs-ff9b8bb838ecdfbfc1dc81038fcf3b2a87636982.zip | |
Initial release
Diffstat (limited to 'utils/analyze_detector_trace.py')
| -rw-r--r-- | utils/analyze_detector_trace.py | 444 | 
1 files changed, 444 insertions, 0 deletions
| diff --git a/utils/analyze_detector_trace.py b/utils/analyze_detector_trace.py new file mode 100644 index 0000000..80099a3 --- /dev/null +++ b/utils/analyze_detector_trace.py @@ -0,0 +1,444 @@ +import sys +import numpy as np +import pickle +import matplotlib.pyplot as plt +import csv +from queue import PriorityQueue +import heapq +from scipy.constants import c, pi, h + + +def waveguide_length_for_phase_shift(phi, wavelength, neff): +    phi = np.mod(phi, 2*pi) +    if phi == 0: +        phi = 2*pi + +    length_req = phi * wavelength / (2 * pi * neff) + +    return length_req + + +class Pulse(): +    max_id = 0 +    def __init__(self, power, duration, phase, tstart, wavelength=1.55e-6, sig_id=0): +        self.power = power +        self.duration = duration +        self.phase = phase +        self.tstart = tstart +        self.wavelength = wavelength +        if sig_id != 0: +            self.id = sig_id +            if self.id >= Pulse.max_id: +                Pulse.max_id = self.id + 1 +        else: +            self.id = Pulse.max_id +            Pulse.max_id += 1 + +    @property +    def tend(self): +        return self.tstart + self.duration + +    @property +    def energy(self): +        return self.power * self.duration + +    def __str__(self): +        return f't={self.tstart}, P={self.power}, tau={self.duration}, phi={self.phase}, lambda={self.wavelength}, id={self.id}' + +    def _cmp_key(self): +        return (self.tstart, self.id) + +    def __lt__(self, rhs): +        return self._cmp_key() < rhs._cmp_key() + +    # def __le__(self, rhs): +    #     return self.tstart <= rhs.tstart +    # +    # def __gt__(self, rhs): +    #     return rhs < self +    # +    # def __ge__(self, rhs): +    #     return rhs <= self + +    def __eq__(self, rhs): +        #return self.tstart == rhs.tstart +        return self._cmp_key() == rhs._cmp_key() +        # return (self.power == rhs.power +        #     and self.duration == rhs.duration +        #     and self.phase == rhs.phase +        #     and self.tstart == rhs.tstart +        #     and self.wavelength == rhs.wavelength +        #     ) +     +    # def __neq__(self, rhs): +    #     return not self == rhs + +    def intersects(lhs, rhs): +        # sort by tstart +        #pulses = sorted([lhs, rhs], key=lambda x: x.tstart) + +        # if they don't interfere +        if lhs.tstart <= rhs.tstart and lhs.tend - rhs.tstart > 1e-20: +            return True +        if rhs.tstart <= lhs.tstart and rhs.tend - lhs.tstart > 1e-20: +            return True + +        return False + +    def __add__(lhs, rhs): +        # sort by tstart +        pulses = sorted([lhs, rhs]) + +        assert(pulses[0].tstart <= pulses[1].tstart) + +        # if they don't interfere, return both pulses +        if pulses[0].tend <= pulses[1].tstart: +            return pulses + +        if pulses[0].wavelength != pulses[1].wavelength: +            raise ValueError('Cannot handle addition of pulses of different frequencies with this class') + +        #print(f'Summing pulses with timestamps {pulses[0].tstart} and {pulses[1].tstart}') +        wavelength = pulses[0].wavelength + +        phi1 = pulses[0].phase +        P1 = pulses[0].power +        A1 = np.sqrt(P1) + +        phi2 = pulses[1].phase +        P2 = pulses[1].power +        A2 = np.sqrt(P2) + +        dphi = phi2 - phi1 + +        #print(f'dphi = {np.mod(dphi,2*pi)/pi:.3f}*pi') +        #print(f'dphi/2 = {np.mod(dphi/2,2*pi)/pi:.3f}*pi') +        #print(f'cos(dphi/2) = {np.cos(dphi/2)}') +        A_sum = (A1 + A2) * np.cos(dphi/2) - 1j * (A1 - A2) * np.sin(dphi/2) +        #P_sum = np.real(A_sum * np.conj(A_sum)) +        P_sum = np.abs(A_sum) ** 2 +        phi_sum = (phi1 + phi2) / 2 + np.angle(A_sum) + +        last_pulse = 1 +        if pulses[0].tend > pulses[1].tend: +            last_pulse = 0 + +        tstart_pre = pulses[0].tstart +        tstart_sum = pulses[1].tstart +        tstart_post = pulses[not last_pulse].tend + +        duration_pre = tstart_sum - tstart_pre +        duration_sum = tstart_post - tstart_sum +        duration_post = pulses[last_pulse].tend - tstart_post + +        P_post = [P1, P2][last_pulse] +        phi_post = [phi1, phi2][last_pulse] + +        assert(duration_pre >= 0) +        assert(duration_sum >= 0) +        assert(duration_post >= 0) + +        result = [] +        if duration_pre > 0: +            result.append(Pulse(P1, duration_pre, phi1, tstart_pre, wavelength)) + +        result.append(Pulse(P_sum, duration_sum, phi_sum, tstart_sum, wavelength)) + +        if duration_post > 0: +            result.append(Pulse(P_post, duration_post, phi_post, tstart_post, wavelength)) + +        return result + +    @classmethod +    def sort_by_tstart(cls, pulses): +        return sorted(pulses) + + +class PulseAggreg(): +    def __init__(self, pulses): +        self.reduced = False +        self.pulses = pulses +        heapq.heapify(self.pulses) +        self.sort_pulses_by_tstart() +        # print(str(self)) +     +    def sort_pulses_by_tstart(self): +        #heapq.heapify(self.pulses) +        self.pulses = Pulse.sort_by_tstart(self.pulses) +        pass + +    @property +    def has_intersections(self): +        if self.reduced: +            return False + +        heapq.heapify(self.pulses) +        for i,p in enumerate(self.pulses[:-1]): +            p2 = self.pulses[i+1] +            if p.intersects(p2): +                return True +        # for i,p in enumerate(self.pulses[:-1]): +        #     for i2,p2 in enumerate(self.pulses[i+1:]): +        #         if p.intersects(p2): +        #             return True +        return False + +    def reduce_pulses_coherent(self): +        if not self.has_intersections: +            self.reduced = True +            return + +        pulses = self.pulses +        pulses_new = [] +         +        n = len(pulses) +        i = 0 +        print('Reducing to independent pulses...') +        while len(pulses) > 1: +            print(f'\r{100*(1-len(pulses)/n):.0f}% ({len(pulses)} pulses left)', end='') +            #pulses = Pulse.sort_by_tstart(pulses) +            p0 = heapq.heappop(pulses) +            p1 = heapq.heappop(pulses) + +            if p0.intersects(p1): +                resulting_pulses = p0 + p1 +                #heapq.heappop(pulses) +                # print(p0) +                # print(p1) +                # print('-->') +                for px in resulting_pulses: +                    heapq.heappush(pulses, px) +                #heapq.heappush(pulses_new, resulting_pulses[0]) +                    # print(px) +                # print(' ') +            else: +                heapq.heappush(pulses_new, p0) +                heapq.heappush(pulses, p1) + + +        if len(pulses) == 1: +            heapq.heappush(pulses_new, heapq.heappop(pulses)) + +        print(f'\r{100*(1-len(pulses)/n):.0f}% ({len(pulses)} pulses left)') +        self.pulses = pulses_new + +        print('Checking...') +        if self.has_intersections: +            raise RuntimeError('Impossible to reduce pulse list') + +        with open('NonIntersectingPulses.obj', 'wb') as f: +            pickle.dump(self, f) + +        print('Done') +        self.reduced = True + +    def __str__(self): +        ret = '\n'.join([str(p) for p in list(self.pulses)]) +        return ret + +    @property +    def energy(self): +        if self.has_intersections: +            self.reduce_pulses_coherent() +        return np.sum([p.energy for p in self.pulses]) + + +    def to_waveform(self, dt, coherent=True, tmax=None): +        if self.has_intersections: +            self.reduce_pulses_coherent() + +        #print(np.min([p.duration for p in self.pulses])) +        #print(np.max([p.power for p in self.pulses])) +        #print(np.min([p.power for p in self.pulses])) +        #print(str(self)) +        #print(self.energy) + +        if tmax == None: +            tmax = self.pulses[-1].tend + +        t = np.arange(0, tmax, dt) +        Pout = np.zeros(len(t)) + +        for p in self.pulses: +            # itmin = int(np.ceil(p.tstart / dt)) +            # itmax = int(np.ceil(p.tend / dt) + 1) +            Pout[np.logical_and(t >= p.tstart, t < p.tend)] = p.power +            #Pout[itmin:itmax] = p.power + +        return t, Pout + +    def to_waveform_2(self, coherent=True): +        if self.has_intersections: +            self.reduce_pulses_coherent() + +        #print(self.energy) + +        t = [0] +        Pout = [0] + +        for p in self.pulses: +            if p.duration < 1e-20: +                continue +            if p.tstart - t[-1] > 1e-20: +                t.append(t[-1]) +                Pout.append(0) +                t.append(p.tstart) +                Pout.append(0) +            t.append(p.tstart) +            Pout.append(p.power) +            t.append(p.tend) +            Pout.append(p.power) + +        return np.array(t), np.array(Pout) + +    def plot_waveform(self, dt, title=None, tmax=None): +        print(f'Mapping to timeseries ({len(self.pulses)} pulses to process)') +        #t, Pout = self.to_waveform(dt, tmax=tmax) +        t, Pout = self.to_waveform_2() +        print('Done') + +        with open('Pout.obj', 'wb') as f: +            pickle.dump((t,Pout), f) + +        linespec='-' +        if title is not None and '.bk' in title: +            linespec='--' +        plt.plot(1e9*t, 1e3*Pout, linespec, label=title) +        if tmax is not None: +            plt.xlim([0, 1e9*tmax]) +        #plt.title(title) +        plt.xlabel('t (ns)') +        plt.ylabel('Pout (mW)') +        plt.savefig('Pout.png', dpi=300) +        plt.grid('major') +        plt.grid('minor') +        plt.legend() + + +def main(filename='detector_trace.txt', override_lambda=None): +    pulses = [] +    tmax = 0 +    taumin = 1 + +    with open(filename, 'r') as f: +        print('Reading trace file...') +        fieldnames = [h.strip() for h in next(csv.reader(f))] + +        reader = csv.DictReader(f, fieldnames=fieldnames) +        for (i, row) in enumerate(reader): +            # if i > 470: +            #     break +            #power = np.round(float(row['P (W)']), 12) +            power = float(row['P (W)']) +            #tau = np.round(float(row['tau (s)']), 16) +            tau = float(row['tau (s)']) +            phase = float(row['phi (rad)']) +            #tstart = np.round(float(row['t (s)']), 16) +            tstart = float(row['t (s)']) +            wavelength = np.round(float(row['lambda (m)']), 12) +            if override_lambda is not None: +                wavelength = override_lambda +            sig_id = int(row['id']) + +            #if tau < 1e-16: continue + +            #print(wavelength) +            p = Pulse(power, tau, phase, tstart, wavelength, sig_id) +            pulses.append(p) +            #tmax = max(tmax, p.tend) +            taumin = min(taumin, tau) + +    print(f'Smallest pulse duration: {taumin}') +    pulses = PulseAggreg(pulses) +    print(f'Some pulses intersect: {pulses.has_intersections}') +    pulses.reduce_pulses_coherent() + +    #t = np.arange(0, tmax, 10e-12) +    #dt = 10e-12 +    dt = 20e-12 +    print(f'Total energy: {pulses.energy}') +    return pulses + +def create_bitstream(n): +    rng = np.random.default_rng() +    bitstream = rng.integers(low=0, high=2, size=n) +    return bitstream + +def xor_bitstream(bitstream): +    bitstream_a = bitstream +    bitstream_b = [0, *bitstream_a] +    bitstream = [a ^ b for a, b in zip(bitstream_a, bitstream_b)] +    return bitstream + +def bitstream_as_values(bitstream, nbits_per_value=8, pad_with=0): +    padding_length = nbits_per_value - (len(bitstream) % nbits_per_value) +    bitstream = [*bitstream, *[pad_with for i in range(padding_length)]] +     +    values = [] +    for i in range(len(bitstream) // nbits_per_value): +        values.append(int(''.join([str(b) for b in bitstream[i:i+nbits_per_value]]), 2)) +    return values + +def test_bitstreams(): +    bs = create_bitstream(20) +    bs_xor = xor_bitstream(bs) +    values_1bit = bitstream_as_values(bs, 1) +    values_2bit = bitstream_as_values(bs, 2) + +    plt.figure() +    plt.subplot(2,2,1) +    plt.plot(bs) +    plt.subplot(2,2,2) +    plt.plot(bs_xor) +    plt.subplot(2,2,3) +    plt.plot(values_1bit) +    plt.subplot(2,2,4) +    plt.plot(values_2bit) +    plt.show() + +def create_random_bitstream_file(filename, nbits=3000, nbits_per_value=8): +    bs = create_bitstream(nbits) +    #bs_xor = xor_bitstream(bs) +    values = bitstream_as_values(bs, nbits_per_value) +    with open(filename, 'w') as f: +        f.write(' '.join([str(v) for v in values])) + +def compare_Pout_vecs(): +    with open('Pout_nosort.obj', 'r') as f: +        data = pickle.load(f) +    t_sort, Pout_sort = pickle.load(open('Pout_sort.obj')) + +    assert(len(Pout_sort) == len(Pout_nosort)) +    for t, p1, p2 in zip(t_nosort, Pout_nosort, Pout): +        assert(p1 == p2) + +if __name__ == '__main__': +    print(waveguide_length_for_phase_shift(2*pi, 1550e-9, 2.2111)) +    print(waveguide_length_for_phase_shift(pi, 1550e-9, 2.2111)) +    print(waveguide_length_for_phase_shift(1, 1550e-9, 2.2111)) + +    a = waveguide_length_for_phase_shift(2*pi, 1551e-9, 2.2111) +    b = waveguide_length_for_phase_shift(pi, 1551e-9, 2.2111) +    print(a) +    print(b) +    print(500*a+b) +    #exit(0) +    # print(700*waveguide_length_for_phase_shift(2*pi, 1550e-9, 2.2111)) +    # compare_Pout_vecs() +    if len(sys.argv) > 1: +        P = [] +        for fn in sys.argv[1:]: +            P.append(main(filename=fn)) +            P[-1].plot_waveform(1e-13,title=fn, tmax=2.5e-9) + +        if len(sys.argv) > 2: +            plt.figure() +            f0 = P[0].to_waveform(1e-13, tmax=2.5e-9) +            f1 = P[1].to_waveform(1e-13, tmax=2.5e-9) +            plt.plot(f0[0], f1[1] - f0[1]) +        plt.show() +    else: +        main() +        plt.show() +    #test_bitstreams() +    #create_random_bitstream_file('bitstream.txt', 3000) | 
