# -*- Mode: Python; py-indent-offset: 4 -*-
#
# Copyright (C) 2003,2004,2007  Ray Burr
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# See: <http://www.gnu.org/copyleft/gpl.html>
#

"""
Python iterators that implement Linear-Feedback Shift Registers.

:authors: Ray Burr
:license: GPL
:contact: http://www.nightmare.com/~ryb/

Examples
========

Create a pseudo-random sequence.  This is a maximal-length sequence for
a 4-bit LFSR.  Notice that it repeats every 15 bits.  The
SequenceGenerator returns an iterator, so here it is used to create a
list to store the sequence:

  >>> m = list(SequenceGenerator(4, makePolyMask((4, 3, 0)), length=64))
  >>> print "".join(map(str, m))
  0101100100011110101100100011110101100100011110101100100011110101

Choose a generating polynomial for the scrambler.  This one is used in
the ITU-T V.34 telephone modem standard by the answer mode modem:
 
  >>> gpa = makePolyMask((0, -5, -23), 23)

Scramble the sequence created above:

  >>> s = list(Scrambler(m, 23, gpa))
  >>> print "".join(map(str, s))
  0101101111000000101101110011001101111110111000111011001110010100

Descramble it, and get back the original sequence:

  >>> d = Descrambler(s, 23, gpa)
  >>> print "".join(map(str, d))
  0101100100011110101100100011110101100100011110101100100011110101

Here is another demonstration of generation, scrambling, and
descrambling with a longer 1000 bit test sequence.  This uses a 65-bit
LFSR maximal-length sequence, which would repeat after about 3.7x10^19
bits.  Since the scrambler returns an iterator, scrambling and
descrambling happen simultaneously; the intermediate scrambled
sequence is not stored in a list:

  >>> m = list(SequenceGenerator(65, makePolyMask((65, 47, 0)), length=1000))
  >>> s = Scrambler(m, 23, gpa)
  >>> d = Descrambler(s, 23, gpa)

Check the result.  The scrambling and descrambling happen here:

  >>> list(d) == m
  True
"""

# Module version string (appears to be a YYYYMMDD date -- confirm).
__version__ = "20070612"


import operator


class Scrambler:
    """
    Iterator that scrambles a bit stream with a shift register fed by
    its own output.

    Each output bit is the input bit XORed with the parity of the
    register bits selected by `taps`.  The output bit is then shifted
    into the register at bit ``width - 1`` while the oldest bit (bit 0)
    is dropped.

    Only bits ``0 .. width - 1`` of `taps` take part in the scrambler;
    higher bits are ignored here.

    :param source: iterable yielding the bits to scramble (expected to
        be 0 or 1; the values are not validated).
    :param width: number of bits in the shift register.
    :param taps: bit mask selecting the register bits that are XORed
        into each output bit.
    :param seed: initial register contents; masked to `width` bits.

    The current register contents are available as the `value`
    attribute.
    """

    def __init__(self, source, width, taps, seed=0):
        self._width = width
        self._taps = taps
        self._source = iter(source)
        # Keep only the low `width` bits of the seed as the start state.
        self.value = seed & ((1 << width) - 1)

    def __iter__(self):
        return self

    def next(self):
        """
        Return the next scrambled bit.

        Raises StopIteration when `source` is exhausted.
        """
        # The builtin next() drives both Python 2 (.next) and
        # Python 3 (.__next__) iterators.
        bit = next(self._source)
        bit ^= _parity(self.value & self._taps, self._width)
        # Insert the output bit just above the register, then shift the
        # whole register down by one so the bit lands at width - 1.
        self.value = (self.value | (bit << self._width)) >> 1
        return bit

    # Python 3 spells the iterator method __next__; provide both names.
    __next__ = next


class Descrambler:
    """
    Iterator that descrambles a bit stream with a shift register fed by
    the received (scrambled) bits.

    For each input bit, the bit is placed at bit `width` just above the
    register, and the output is the parity of the bits selected by
    `taps` over bits ``0 .. width``.  The register is then shifted down
    by one, so the input bit ends up at bit ``width - 1`` and the oldest
    bit (bit 0) is dropped.

    Because bit `width` of `taps` selects the incoming bit itself, the
    mask must have that bit set for the output to depend on the current
    input.  After `width` input bits have been consumed, the register
    contents no longer depend on `seed`.

    :param source: iterable yielding the scrambled bits (expected to be
        0 or 1; the values are not validated).
    :param width: number of bits in the shift register.
    :param taps: bit mask over bits ``0 .. width`` selecting the bits
        that are XORed together to form each output bit.
    :param seed: initial register contents; masked to `width` bits.

    The current register contents are available as the `value`
    attribute.
    """

    def __init__(self, source, width, taps, seed=0):
        self._width = width
        self._taps = taps
        self._source = iter(source)
        # Keep only the low `width` bits of the seed as the start state.
        self.value = seed & ((1 << width) - 1)

    def __iter__(self):
        return self

    def next(self):
        """
        Return the next descrambled bit.

        Raises StopIteration when `source` is exhausted.
        """
        # The builtin next() drives both Python 2 (.next) and
        # Python 3 (.__next__) iterators.
        received = next(self._source)
        # Place the received bit above the register so the tap mask can
        # select it together with the stored history.
        self.value |= (received << self._width)
        bit = _parity(self.value & self._taps, self._width + 1)
        self.value >>= 1
        return bit

    # Python 3 spells the iterator method __next__; provide both names.
    __next__ = next


class SequenceGenerator:
    """
    Iterator producing the bit sequence of a linear-feedback shift
    register.

    On every step the feedback bit is the parity of the register bits
    selected by `taps` (inverted when `xnor` is true).  That bit is
    shifted into the register at bit ``width - 1``, the oldest bit
    (bit 0) is dropped, and the feedback bit is returned.

    Only bits ``0 .. width - 1`` of `taps` take part; higher bits of
    the mask are ignored.

    :param width: number of bits in the shift register.
    :param taps: bit mask selecting the register bits used for feedback.
    :param seed: initial register contents; masked to `width` bits.  The
        default of -1 therefore starts with all bits set.
    :param xnor: if true, invert the feedback bit (XNOR feedback).
    :param length: number of bits to produce before raising
        StopIteration, or None for an endless sequence.

    Note: with `xnor` false an all-zero register produces only zeros
    forever.  With `xnor` true an all-one register stays all ones when
    an even number of taps lie inside the register, so pass an explicit
    `seed` in that case.

    The current register contents are available as the `value`
    attribute.
    """

    def __init__(self, width, taps, seed=-1, xnor=False, length=None):
        self._width = width
        self._taps = taps
        # Stored as an int so it can be XORed directly into the feedback.
        self._xnor = 1 if xnor else 0
        self._length = length
        self.value = seed & ((1 << width) - 1)
        self._count = 0

    def __iter__(self):
        return self

    def next(self):
        """
        Return the next bit of the sequence.

        Raises StopIteration once `length` bits have been produced.
        """
        # Check before counting so the counter stops at `length` instead
        # of growing on every call made after exhaustion.
        if self._length is not None and self._count >= self._length:
            raise StopIteration
        self._count += 1
        return self._shift()

    # Python 3 spells the iterator method __next__; provide both names.
    __next__ = next

    def _shift(self):
        """Advance the register one step and return the feedback bit."""
        bit = _parity(self.value & self._taps, self._width) ^ self._xnor
        # Insert the new bit just above the register, then shift down so
        # it lands at bit width - 1.
        self.value = (self.value | (bit << self._width)) >> 1
        return bit


def makePolyMask(exponents, offset=0):
    """
    Build a tap bit mask from a sequence of polynomial exponents.

    Bit ``x + offset`` is set in the result for every ``x`` in
    `exponents`.  The bits are combined with bitwise OR, so listing an
    exponent more than once has no further effect (an arithmetic sum
    would carry into the neighbouring bit and yield a wrong mask).

    :param exponents: iterable of integer exponents.
    :param offset: value added to every exponent, which allows negative
        exponents to be expressed (e.g. ``(0, -5, -23)`` with offset 23).
    :returns: the integer mask; 0 if `exponents` is empty.
    :raises ValueError: if ``x + offset`` is negative for any exponent
        (negative shift count).
    """
    mask = 0
    for exponent in exponents:
        mask |= 1 << (exponent + offset)
    return mask

def _parity(x, width):
    """
    Return the XOR (0 or 1) of the `width` least-significant bits of `x`.

    Bits at position `width` and above are ignored; a `width` of zero
    or less yields 0.
    """
    # The low bit of the ones-count equals the XOR of the individual bits.
    ones = sum((x >> position) & 1 for position in range(width))
    return ones & 1

def _formatBinaryString(value, width):
    """
    Return the low `width` bits of `value` as a string of '0' and '1'
    characters, most-significant bit first.

    A `width` of zero or less gives an empty string.
    """
    digits = []
    # Walk from bit width - 1 down to bit 0.
    for shift in reversed(range(width)):
        digits.append(str((value >> shift) & 1))
    return "".join(digits)


def _test():
    """Run the doctests of this module and return doctest.testmod()'s result."""
    import doctest, sys
    this_module = sys.modules[__name__]
    outcome = doctest.testmod(this_module)
    return outcome

# Run the module's doctests when this file is executed as a script.
if __name__ == "__main__":
    _test()
