# demfm.py
#
# MFM bitstream parser/decoder
#
# Written by Denis Dratov <volutar@gmail.com>
#
# This is free and unencumbered software released into the public domain.
# See the file COPYING for more details, or visit <http://unlicense.org>.
import sys

class DeMFM:
  """De-MFM and logical parse of an unsynchronized MFM bitstream.

  decode_mfm() is a classmethod and stores its results on the class itself,
  so they are shared by every DeMFM instance:
    bytes -- decoded bytes (bytearray)
    syncs -- sync-mark bitmap, 1 bit per decoded byte (bytearray)
    cat   -- catalog of address-mark records (list of lists), see decode_mfm()
  """
  crctable = [] #shared CRC-16 lookup table (poly 0x1021), filled by _init_crctable()
  _curcrc = 0xffff #running CRC updated by crc16add()
  bytes=bytearray()
  syncs=bytearray()
  cat=list()

  def __init__(self):
    self._init_crctable()

  @classmethod
  def _init_crctable(cls):
    """Fill the shared crctable (CRC-16, poly 0x1021, MSB first) unless it is already built."""
    if cls.crctable: #already created
      return
    poly = 0x1021
    for byte in range(256):
      w = byte << 8
      for _ in range(8):
        if w & 0x8000:
          w = (w << 1) ^ poly
        else:
          w <<= 1
      cls.crctable.append(w & 0xffff)

  @classmethod
  def crc16add(cls,byte):
    """Feed one byte into the running CRC kept in cls._curcrc.

    Uses the shared crctable, which __init__() and decode_mfm() build.
    """
    cls._curcrc = cls.crctable[((cls._curcrc >> 8) ^ byte) & 0xff] ^ ((cls._curcrc << 8) & 0xffff)

  @staticmethod
  def unmfm(mfm):
    """Return the 8 data bits of a 16-bit MFM word (bits 14,12,...,0); clock bits are dropped."""
    bt=0
    if mfm&0x0001: bt|=0x01
    if mfm&0x0004: bt|=0x02
    if mfm&0x0010: bt|=0x04
    if mfm&0x0040: bt|=0x08
    if mfm&0x0100: bt|=0x10
    if mfm&0x0400: bt|=0x20
    if mfm&0x1000: bt|=0x40
    if mfm&0x4000: bt|=0x80
    return bt

  @staticmethod
  def ami_unshuffle(bytes,start,ln):
    """Merge Amiga odd/even-split data back into plain bytes.

    bytes[start:start+ln] hold the odd bits and bytes[start+ln:start+2*ln]
    the even bits. Each pair (odd byte, even byte) is interleaved into one
    big-endian 16-bit word. Returns a bytearray of 2*ln bytes.
    """
    rt=bytearray()
    p1=start
    p2=start+ln
    for i in range(ln):
      a1=bytes[p1]
      a2=bytes[p2]
      dec=0
      for j in range(8):
        dec<<=2
        if a1&0x80: dec|=2
        if a2&0x80: dec|=1
        a1<<=1
        a2<<=1
      rt.append(dec>>8)
      rt.append(dec&0xff)
      p1+=1
      p2+=1
    return rt

  @staticmethod
  def ami_shuffle(bytes,start,ln):
    """Split plain data into Amiga odd/even form (inverse of ami_unshuffle).

    Reads ln big-endian 16-bit words from bytes[start:start+2*ln]. Returns a
    bytearray holding ln bytes of odd bits (15,13,...,1) followed by ln bytes
    of even bits (14,12,...,0).
    """
    rt1=bytearray()
    rt2=bytearray()
    p1=start
    for i in range(ln):
      dec=(bytes[p1]<<8) | bytes[p1+1]
      a1=a2=0
      for j in range(8):
        a1<<=1
        a2<<=1
        if dec&0x8000: a1|=1
        if dec&0x4000: a2|=1
        dec<<=2
      rt1.append(a1)
      rt2.append(a2)
      p1+=2
    return rt1+rt2

  @staticmethod
  def _xor_words(buf,start,end):
    """Return the XOR of the big-endian 16-bit words in buf[start:end]."""
    chk=0
    for i in range(start,end,2):
      chk^=(buf[i]<<8)|buf[i+1]
    return chk

  @classmethod
  def _fix_presync_zeros(cls,dat,idx,b):
    """Re-decode the twelve bytes preceding an A1 sync and force decoded zeros.

    The sync ends at bit b of dat[idx]. This walks the raw stream backwards
    from dat[idx-4] to dat[idx-27], shifting the bits into a window positioned
    by b. After every second raw byte it decodes the window. When the result
    is 0, the matching byte in cls.bytes (-1 down to -12) is set to 0. For the
    12th byte, which may be incomplete, the low 6 bits being 0 is enough.
    Presumably this repairs gap zeros decoded before the sync fixed the byte
    alignment.
    """
    k=0
    z=-1
    ormask=0x10000<<b
    for y in range(4,28):
      if idx<y: #start of stream reached; dat[idx-y] would wrap around to its end
        break
      j=dat[idx-y]
      for _ in range(8):
        k>>=1
        if j&0x80: k|=ormask
        j<<=1
      if y%2==1: #two raw bytes per decoded byte
        continue
      l=cls.unmfm(k)
      if l==0: cls.bytes[z]=0
      if z==-12 and l&0x3f==0: cls.bytes[z]=0 #first byte of 12x0 may be incomplete
      z-=1

  @classmethod
  def decode_mfm(cls,dat):
    """Decode and parse an MFM bitstream.

    dat -- raw MFM stream: an indexable sequence of ints 0-255 (bytes,
           bytearray, list). Bits are consumed least-significant first within
           each byte.

    Stores on the class:
      bytes -- decoded byte list (byte boundaries re-synchronized on sync marks)
      syncs -- bitmap, 1 bit per decoded byte: (syncs[x//8]>>(x%8))&1
      cat   -- track catalog, one record per address mark:
               [type, id, offset, datalen, crc, cylinder, side, sector]
        type     1=IDAM, 2=DAM, -1=unknown mark. An interrupted record keeps
                 the type that was in progress, possibly 0.
        id       address mark byte after the A1 syncs (e.g. FE/FB); 1/2 for Amiga
        offset   index in bytes of the first byte after the mark
                 (Amiga: first header byte / first data byte)
        datalen  data length; the CRC follows unless the record was interrupted
                 (IDAM 4, DAM size from the latest IDAM, Amiga header 22 / data 512)
        crc      1=OK 0=BAD -1=no CRC (interrupted by a sync or the end of stream)
        cylinder, side, sector  from the latest IDAM / Amiga header

    IBM records: three A1 syncs, then a mark byte.
      FC-FF = IDAM: cylinder, side, sector, size code, then CRC.
      F8-FB = DAM: size taken from the latest IDAM, then CRC.
    A DAM seen before any IDAM has an unknown size and is skipped up to the
    next sync.

    Amiga records: two A1 syncs, then a 540-byte sector checked with XOR
    checksums. Its data is left in odd/even-split form in bytes;
    ami_unshuffle(bytes, offset, 256) restores it. Once an Amiga sector has
    been seen, IBM parsing stays off until a sector fails the Amiga format
    check.
    """
    cls._init_crctable() #crc16add() needs the table even when no instance was created
    cls.bytes=bytearray()
    cls.syncs=bytearray()
    cls.cat=[]
    curbyte=0      #last 16 MFM bits, newest in bit 0
    bit=16         #MFM bits left until the next byte is emitted
    mark=0         #1 = A1 sync ends on this byte, 2 = C2 sync ended one bit ago
    sync=0         #sync bitmap byte being assembled
    sync_cnt=8     #bits still missing in it
    # runstage: -1 skipping a displaced DAM, 0 gap, 1-3 A1 syncs seen,
    #           4 expecting a mark (IBM) / collecting a sector (Amiga),
    #           5-8 IDAM fields, 9 DAM data and CRC
    runstage=0
    runarea=0      #bytes left in the current record, CRC included
    readcrc=0
    cur_type=0
    cur_id=0
    cur_offset=0
    cur_cyl=0
    cur_side=0
    cur_sec=0
    cur_seclen=0   #sector size from the latest IDAM, 0 = no IDAM seen yet
    idx=-1
    amiga=False
    amibuf=bytearray()

    for x in dat:
      idx+=1
      for b in range(8):
        curbyte=((curbyte<<1)|(x&1))&0xffff
        x>>=1
        bit-=1

        #sync marks
        if runstage==0 and curbyte==0x9254: #1001001001010100 = gap byte 4E: realign
          mark=0
          bit=0
        if curbyte==0x5224<<1: #1010010001001000 = C2 sync one bit ago: next byte starts here
          bit=15
          mark=2
        if curbyte==0x4489: #0100010010001001 = A1 sync
          if runstage==0 and len(cls.bytes)>=12: #sync in a true gap (not inside a sector)
            cls._fix_presync_zeros(dat,idx,b)
          mark=1
          bit=0
          if runarea>0 and runstage>3: #sync inside a record: close it as interrupted
            runstage=0
            cls.cat.append([cur_type,cur_id,cur_offset,len(cls.bytes)-cur_offset,-1,cur_cyl,cur_side,cur_sec])
          if runstage==4:
            runstage=0
          elif runstage==3:
            runstage=0
          elif runstage==2:
            runstage=3
          elif runstage==1:
            runstage=2
          elif runstage<=0: #gap, or end of a skipped displaced DAM: new sync run
            cls._curcrc=0xffff #init crc
            readcrc=0
            runarea=4+4+2
            runstage=1
        elif bit==0: #byte boundary without a sync
          if runstage==2: #exactly two A1 syncs: Amiga sector
            amiga=True
            amibuf=bytearray()
            cur_offset=len(cls.bytes)
            runarea=20+2+2+512+4
            runstage=4
          elif runstage==1: #a single A1 is not a sync run
            runstage=0
            runarea=0

        if bit!=0:
          continue #between bytes
        decbyte=cls.unmfm(curbyte)
        cls.bytes.append(decbyte)
        bit=16

        #sync bitmap: a C2 mark lands on the previous byte, an A1 mark on this one
        if mark==2: sync|=0x80
        sync>>=1
        if mark==1: sync|=0x80
        sync_cnt-=1
        if sync_cnt==0:
          cls.syncs.append(sync)
          sync_cnt=8
        mark=0

        if amiga:
          if runstage==4 and runarea>=1: #collecting an Amiga sector
            amibuf.append(decbyte)
            runarea-=1
            if runarea==0: #sector complete: finish it now, not on the following byte
              amidec=cls.ami_unshuffle(amibuf,0,2) #format, track, sector, sectors to gap
              if amidec[0]!=0xff: #format byte is not 0xFF - not amiga
                amiga=False
              else:
                cur_cyl=amidec[1]//2
                cur_side=amidec[1]%2
                cur_sec=amidec[2]
                #checksum words are included in the XOR, so an intact area XORs to 0
                cur_crc=1 if cls._xor_words(amibuf,0,24)==0 else 0 #header incl. checksum (0-23)
                cls.cat.append([1,1,cur_offset,22,cur_crc,cur_cyl,cur_side,cur_sec])
                cur_crc=1 if cls._xor_words(amibuf,26,540)==0 else 0 #data checksum (26-27) + data
                cls.cat.append([2,2,cur_offset+28,512,cur_crc,cur_cyl,cur_side,cur_sec])
              runstage=0
        elif runstage>0:
          if runstage==8: #size code
            cur_seclen=128*(1<<decbyte) # &3 would emulate 2bit WD1772 behavior
            runstage+=1
          elif runstage==7: #sector
            cur_sec=decbyte
            runstage+=1
          elif runstage==6: #side
            cur_side=decbyte
            runstage+=1
          elif runstage==5: #cylinder
            cur_cyl=decbyte
            runstage+=1
          elif runstage==4: #address mark after the A1 syncs
            cur_offset=len(cls.bytes)
            cur_id=decbyte
            if decbyte>=0xfc: #IDAM
              cur_type=1
              runarea=4+3 #mark + 4 ID bytes + 2 CRC bytes
              runstage+=1
            elif decbyte>=0xf8: #DAM
              cur_type=2
              if cur_seclen==0: #displaced DAM (no IDAM yet): size unknown, skip to next sync
                runstage=-1
              else:
                runarea=cur_seclen+3 #mark + data + 2 CRC bytes
                runstage=9
            else:  #illegal code (not f8-ff)
              cur_type=-1
              runstage=0
          elif runstage==3: #third A1
            runstage+=1

          if cur_type==-1: #unknown sync type
            cls.cat.append([cur_type,cur_id,cur_offset,0,0,0,0,0])
            cur_type=0

          if runarea==1: #second CRC byte: record complete
            readcrc|=decbyte
            cur_crc=1 if readcrc==cls._curcrc else 0
            datalen=4 if cur_type==1 else cur_seclen
            cls.cat.append([cur_type,cur_id,cur_offset,datalen,cur_crc,cur_cyl,cur_side,cur_sec])
            cur_type=0 #clear catalog item id
            cur_id=0
            runstage=0
            runarea=0
          elif runarea==2: #first CRC byte
            readcrc=decbyte<<8
            runarea=1
          elif runarea>=3: #sync/mark/data byte covered by the CRC
            cls.crc16add(decbyte)
            runarea-=1

    if 0<sync_cnt<8: #flush the partly filled sync bitmap byte
      cls.syncs.append(sync>>sync_cnt)

    if runstage!=0: #unfinished record at end of stream
      cls.cat.append([cur_type,cur_id,cur_offset,len(cls.bytes)-cur_offset,-1,cur_cyl,cur_side,cur_sec])

  def print_cat(self):
    """Print the catalog and a hex dump of the decoded bytes.

    Each record is printed with up to 8 leading data bytes; DAMs are indented.
    The full hex dump of all decoded bytes follows.
    """
    for i in self.cat:
      if i[0]=='ami': break
      print('\t' if i[0]==2 else '',end='')
      for j in range(min(i[3],8)):
        print('%.2X' %self.bytes[i[2]+j],end=' ')
      print('-',i)

    for x in self.bytes:
      print(format(x,"02X"),end=" ")

    print ('len=',len(self.bytes))
    print()

  def print_cat_short(self):
    """Print a one-line track summary.

    Prints 'Tccc:s' whenever cylinder/side change, then the first DAM's
    length ('lenN'). Each DAM follows as its sector number in hex ('_' if no
    IDAM preceded it). A '!' suffix marks a CRC error and a '\\' suffix an
    interrupted record. A bracketed length precedes DAMs whose length differs
    from the first one.
    """
    cyl=-1
    side=-1
    sec=-1
    slen=-1
    for i in self.cat:
      if i[0]==1: #IDAM
        if cyl!=i[5] or side!=i[6]:
          cyl=i[5]
          side=i[6]
          print ('T%.3d:%d' %(cyl,side),end=' ')
        sec=i[7]
      if i[0]==2: #DAM
        if slen==-1:
          print ('len%d' %i[3],end='  ')
          slen=i[3]

        ssec=format(sec,"02X") if sec>=0 else "_"
        if slen!=i[3]:
          print('[%d]'%i[3],end='')
        if i[4]==1: #CRC ok
          print('%s' %(ssec), end=' ')
        elif i[4]==0: #CRC error
          print('%s!' %(ssec), end=' ')
        elif i[4]==-1: #broken sector
          print('%s\\' %(ssec), end=' ')
        sec=-1
    print()
