fixed default dict parameter of record.read() persistency
[riffle.git] / shuffle.py
blob780daa3ca9121fc9acc15995bc9c271cc7889780
1 """
2 iPod Shuffle database access
4 Documentation:
5 - http://ipodlinux.org/ITunesDB#iTunesSD_file and further
7 Author: Artem Baguinski
8 """
10 from __future__ import with_statement
11 import struct, os, sys
# Byte-order flags handed to Record() and the packers' set_bigendian()
BIG_ENDIAN, LITTLE_ENDIAN = True, False
# open() modes: READ for existing database files, WRITE creates/truncates them
READ, WRITE = 'rb', 'w+b'
### Fields
class BaseField:
    """Base class for record fields.

    Record calls set_bigendian() and set_reclen() on every field; the
    defaults here accept and ignore both.
    """

    def set_bigendian(self, ignore):
        """Ignore the record's byte order."""
        return None

    def set_reclen(self, ignore):
        """Ignore the record's total length."""
        return None
###### Skip is a dummy field that is neither read nor written
class Skip(BaseField):
    """Placeholder for n bytes whose content is ignored.

    Reading and writing both just advance the file position by n bytes,
    so writing leaves whatever bytes are already at that position untouched.
    """
    def __init__(self, n):
        def size():
            return n
        self.get_size = size
        # Writing a skipped area is the same operation as reading it: seek past it
        self.write = self.read

    def read(self, file, dict):
        """Move the file position forward by get_size() bytes."""
        offset = self.get_size()
        file.seek(offset, os.SEEK_CUR)
###### Field is a composite field object factory
class Field(BaseField):
    """A record field composed of a packer and a value handler.

    The packer converts between raw bytes and a Python value (get_size,
    pack, unpack). The value handler decides where a value read from the
    file is stored (put) and where a value to be written comes from (get).
    """

    # Value handlers - know what to do with read values and
    # where to get values to be written
    class Named:
        """Store/load the value under key `name` in the record dict."""
        def __init__(self, name, default):
            self.name = name
            self.default = default
        def get(self, dict):
            # Missing keys fall back to the default given at construction
            return dict.get(self.name, self.default)
        def put(self, dict, value):
            dict[self.name] = value

    class Const:
        """Always write `const`; on read either ignore or verify the value."""
        def __init__(self, const, check):
            self.const = const
            if check:
                self.put = self.check
        def get(self, dict):
            return self.const
        def check(self, dict, value):
            if value != self.const:
                raise ValueError("Format error: expected %r, read %r"
                                 % (self.const, value))
        def put(self, dict, value):
            pass

    # Field factory - composes field from packer and value handler
    def __init__(self, packer, name=None, default=None, const=None, check = False):
        """Build a field around `packer`.

        name == '%reclen%': the value is a constant equal to the enclosing
            record's length, supplied later via set_reclen().
        name given: the value lives under that key (`default` if missing).
        const given: the value is that constant.
        check: for constant fields, raise ValueError when the value read
            from the file differs from the constant.
        Raises ValueError if neither name nor const is given.
        """
        # helpers
        def set_value_handler(vh):
            self.put = vh.put
            self.get = vh.get
        def const_later(const):
            set_value_handler(Field.Const(const, check))

        if name == '%reclen%':
            self.set_reclen = const_later
        elif name is not None:
            set_value_handler(Field.Named(name, default))
        elif const is not None:
            set_value_handler(Field.Const(const, check))
        else:
            raise ValueError("Bad field parameters")

        self.get_size = packer.get_size
        self.pack = packer.pack
        self.unpack = packer.unpack
        # hasattr also finds inherited methods (e.g. Int24/Bool24 inherit
        # set_bigendian from Uint24); looking only at the packer's own
        # class __dict__ would leave those at their default byte order.
        if hasattr(packer, 'set_bigendian'):
            self.set_bigendian = packer.set_bigendian

    def read(self, file, dict):
        """Read get_size() bytes, unpack them and hand the value to put()."""
        self.put(dict, self.unpack(file.read(self.get_size())))

    def write(self, file, dict):
        """Pack the value from get() and write it to file."""
        file.write(self.pack(self.get(dict)))
### Packers
class SimplePacker:
    """Packer for a value described by a struct format string.

    unpack() returns only the first value the format yields.
    """
    def __init__(self, fmt):
        self.fmt = fmt
        self.size = struct.calcsize(fmt)

    def pack(self, val):
        packed = struct.pack(self.fmt, val)
        return packed

    def unpack(self, str):
        values = struct.unpack(self.fmt, str)
        return values[0]

    def get_size(self):
        return self.size
class Uint8(SimplePacker):
    """Unsigned 8-bit integer packer (struct code "B")."""
    def __init__(self):
        fmt = "B"
        SimplePacker.__init__(self, fmt)
class Bool8(Uint8):
    """Boolean stored in one byte: writes 1/0, reads any nonzero byte as True."""
    def pack(self, val):
        byte = int(bool(val))
        return Uint8.pack(self, byte)

    def unpack(self, str):
        return bool(Uint8.unpack(self, str))
class Uint24:
    """Unsigned 24-bit integer packer with selectable byte order.

    struct has no 3-byte integer code, so the value is split into a 1-byte
    high part and a 2-byte low part. struct then rejects values outside
    0..0xffffff with struct.error instead of silently dropping the high
    byte. It also works unchanged on Python 2 str and Python 3 bytes.
    """
    def __init__(self, bigendian = LITTLE_ENDIAN):
        self.bigendian = bigendian

    def get_size(self):
        return 3

    def pack(self, i):
        high, low = i >> 16, i & 0xffff
        if self.bigendian:
            return struct.pack('>BH', high, low)
        return struct.pack('<HB', low, high)

    def unpack(self, s):
        if self.bigendian:
            high, low = struct.unpack('>BH', s[0:3])
        else:
            low, high = struct.unpack('<HB', s[0:3])
        return (high << 16) | low

    def set_bigendian(self, bigendian):
        self.bigendian = bigendian
class Int24(Uint24):
    """Signed 24-bit two's-complement integer packer (range -0x800000..0x7fffff)."""
    def __init__(self, bigendian = LITTLE_ENDIAN):
        self.bigendian = bigendian

    def pack(self, i):
        if not -0x800000 <= i <= 0x7fffff:
            raise struct.error("Int24 value out of range: %r" % (i,))
        # Masking to 24 bits yields the two's-complement pattern of negatives
        return Uint24.pack(self, i & 0xffffff)

    def unpack(self, s):
        u = Uint24.unpack(self, s)
        # Bit 23 (0x800000) is the sign bit of a 24-bit value
        if u & 0x800000:
            return u - 0x1000000
        return u
class Bool24(Int24):
    """Boolean in 3 bytes: True is written as -1 (all bits set), False as 0.

    Any nonzero value reads back as True.
    """
    def __init__(self):
        Int24.__init__(self)

    def pack(self, val):
        if val:
            raw = -1
        else:
            raw = 0
        return Int24.pack(self, raw)

    def unpack(self, str):
        return bool(Int24.unpack(self, str))
class ZeroPaddedString:
    """Fixed-size text field: the encoded string right-padded with NUL bytes.

    len: field size in bytes; enc: codec name used for encode/decode.
    """
    def __init__(self, len, enc):
        self.size = len
        self.enc = enc

    def pack(self, val):
        """Encode val and pad it to exactly `size` bytes.

        Raises ValueError if the encoded value does not fit. An oversized
        value would otherwise grow the record and shift every byte after it.
        """
        encoded = val.encode(self.enc)
        if len(encoded) > self.size:
            raise ValueError("%r needs %d bytes, field holds %d"
                             % (val, len(encoded), self.size))
        # struct's 's' code pads with NUL bytes up to the given width
        return struct.pack('%ds' % self.size, encoded)

    def unpack(self, str):
        """Decode the field; the text ends at the first NUL character."""
        text = str.decode(self.enc)
        # Anything after the first NUL is padding (or stale data), not text
        return text.split('\x00', 1)[0]

    def get_size(self):
        return self.size
### Record - an ordered list of fields
class Record:
    """A fixed binary layout: fields are read and written in list order."""
    def __init__(self, fields, bigendian):
        self.fields = fields
        # The total length is computed before configuring the fields because
        # '%reclen%' fields receive it through set_reclen().
        total = self.get_size()
        for field in fields:
            field.set_bigendian(bigendian)
            field.set_reclen(total)

    def read(self, file, dict=None):
        """Read every field into `dict` (a fresh dict if omitted) and return it."""
        target = {} if dict is None else dict
        for field in self.fields:
            field.read(file, target)
        return target

    def write(self, file, dict):
        """Write every field, taking values from `dict`."""
        for field in self.fields:
            field.write(file, dict)

    def get_size(self):
        """Return the record length in bytes (sum of the field sizes)."""
        return sum(field.get_size() for field in self.fields)
class Track:
    """One song: iTunesSD layout fields plus iTunesStats counters.

    The class attributes below are defaults until a value is read from the
    database or set explicitly.
    """
    supported_file_types = (".mp3", ".aa", ".m4a", ".m4b", ".m4p", ".wav")

    # starttime/stoptime/bookmarktime are in 0.256 s units (see __str__)
    starttime = 0
    stoptime = 0
    volume = 0x64
    bookmarktime = -1   # negative values print as 0; presumably "no bookmark"
    playcount = 0
    skippedcount = 0
    filename = None
    file_type = 0       # set_filename: 1 = .mp3/.aa, 2 = .m4a/.m4b/.m4p, 4 = .wav
    bookmarkflag = False
    shuffleflag = True

    def set_filename(self, filename):
        """Set filename and derive file_type, bookmarkflag and shuffleflag.

        Extensions are matched case-insensitively, so "SONG.MP3" is accepted.
        Raises ValueError for an extension not in supported_file_types.
        """
        self.filename = filename
        lowered = filename.lower()
        if lowered.endswith((".mp3", ".aa")):
            self.file_type = 1
        elif lowered.endswith((".m4a", ".m4b", ".m4p")):
            self.file_type = 2
        elif lowered.endswith(".wav"):
            self.file_type = 4
        else:
            raise ValueError("%s: unsupported file type" % (filename,))
        # .aa/.m4b (presumably audiobooks) get a bookmark and are not shuffled
        self.bookmarkflag = lowered.endswith((".aa", ".m4b"))
        self.shuffleflag = not self.bookmarkflag

    def __str__(self):
        s = "%s\n vol: %d " % (self.filename, self.volume)
        if self.starttime != 0 or self.stoptime != 0:
            s += "%5.3fs-%5.3fs " % (self.starttime*0.256, self.stoptime*0.256)
        if self.bookmarkflag:
            s += "bookmark: %5.3fs " % (max(self.bookmarktime, 0)*0.256)
        if self.shuffleflag:
            s += "shuffle "
        s += "played: %d skipped: %d" % (self.playcount, self.skippedcount)
        return s

    # persistency
    # Registry filled by set_old_tracks(); keyed by filename AND by list index
    old_tracks = {}

    @classmethod
    def new(cls, filename):
        """Return the registered track for filename, or a fresh Track for it.

        A registered track is returned as the same object, not a copy.
        """
        if filename in cls.old_tracks:
            return cls.old_tracks[filename]
        t = cls()
        t.set_filename(filename)
        return t

    @classmethod
    def set_old_tracks(cls, lst):
        """Replace the registry with the tracks in lst (by filename and index)."""
        cls.old_tracks = {}
        for i, t in enumerate(lst):
            cls.old_tracks[t.filename] = t
            cls.old_tracks[i] = t
class PState:
    """Player state as stored in iTunesPState (defaults until read)."""
    volume = 29
    shufflepos = 0
    trackno = 0
    shuffleflag = False
    trackpos = 0

    def __str__(self):
        template = """Player state:
volume: %d
shuffle mode: %s
shuffle position: %d
track number: %d
track position: %d"""
        values = (self.volume, self.shuffleflag, self.shufflepos,
                  self.trackno, self.trackpos)
        return template % values
class ShuffleDB:
    """Reads and writes iTunesSD, iTunesStats and iTunesPState.

    All files are opened relative to the current working directory.
    """
    iTunesSD_hdr = Record([
        Field(Uint24(), 'tracks'),
        Field(Uint24(), const=0x010800),
        Field(Uint24(), '%reclen%', check=True),
        Skip(9)],
        BIG_ENDIAN)

    iTunesSD_track = Record([
        Field(Uint24(), '%reclen%', check=True),
        Skip(3),
        Field(Uint24(), 'starttime'),
        Skip(6),
        Field(Uint24(), 'stoptime'),
        Skip(6),
        Field(Uint24(), 'volume'),
        Field(Uint24(), 'file_type'),
        Skip(3),
        Field(ZeroPaddedString(522, 'UTF-16-LE'), 'filename'),
        Field(Bool8(), 'shuffleflag'),
        Field(Bool8(), 'bookmarkflag'),
        Skip(1)],
        BIG_ENDIAN)

    iTunesStats_hdr = Record([
        Field( Uint24(), 'tracks'),
        Skip(3)],
        LITTLE_ENDIAN)

    iTunesStats_track = Record([
        Field( Uint24(), '%reclen%', check = True),
        Field( Int24(), 'bookmarktime'),
        Skip(6),
        Field( Uint24(), 'playcount'),
        Field( Uint24(), 'skippedcount')],
        LITTLE_ENDIAN)

    iTunesPState = Record([
        Field( Uint8(), 'volume' ),
        Field( Uint24(), 'shufflepos' ),
        Field( Uint24(), 'trackno' ),
        Field( Bool24(), 'shuffleflag'),
        Field( Uint24(), 'trackpos'),
        Skip(19)],
        LITTLE_ENDIAN)

    @staticmethod
    def _values(obj):
        """Return obj's attributes as a dict, class-level defaults included.

        Record.write looks values up by key. Attributes still at their class
        default (e.g. a fresh Track's volume) are absent from obj.__dict__
        and would otherwise reach the packers as None and fail.
        """
        return dict((name, getattr(obj, name))
                    for name in dir(obj) if not name.startswith('__'))

    def write_iTunesSD(self, tracks):
        with open('iTunesSD', WRITE) as f:
            self.iTunesSD_hdr.write(f, {'tracks': len(tracks)})
            for t in tracks:
                self.iTunesSD_track.write(f, self._values(t))
            f.truncate()

    def read_iTunesSD(self):
        """Return a list of Track objects read from iTunesSD."""
        with open('iTunesSD', READ) as f:
            num_tracks = self.iTunesSD_hdr.read(f)['tracks']
            tracks = []
            for _ in range(num_tracks):
                t = Track()
                # Values go straight into the instance's attributes
                self.iTunesSD_track.read(f, t.__dict__)
                tracks.append(t)
            return tracks

    def write_iTunesStats(self, tracks):
        with open('iTunesStats', WRITE) as f:
            self.iTunesStats_hdr.write(f, {'tracks': len(tracks)})
            for t in tracks:
                self.iTunesStats_track.write(f, self._values(t))
            f.truncate()

    def read_iTunesStats(self, tracks):
        """Fill play statistics into `tracks` (in file order).

        Raises ValueError if the file's track count differs from len(tracks).
        """
        with open('iTunesStats', READ) as f:
            num_tracks = self.iTunesStats_hdr.read(f)['tracks']
            if num_tracks != len(tracks):
                raise ValueError("Inconsistent number of songs in iTunesSD and iTunesStats")
            for t in tracks:
                self.iTunesStats_track.read(f, t.__dict__)

    def write_iTunesPState(self, pstate):
        # Update an existing file in place ('r+b') so the bytes covered by
        # Skip fields keep their current content; otherwise create it.
        mode = 'r+b' if os.path.exists('iTunesPState') else WRITE
        with open('iTunesPState', mode) as f:
            self.iTunesPState.write(f, self._values(pstate))
            f.truncate()

    def read_iTunesPState(self):
        pstate = PState()
        with open('iTunesPState', READ) as f:
            self.iTunesPState.read(f, pstate.__dict__)
        return pstate

    def read_all(self):
        """Return (tracks, pstate) read from all three files."""
        tracks = self.read_iTunesSD()
        self.read_iTunesStats(tracks)
        pstate = self.read_iTunesPState()
        return (tracks, pstate)

    def write_all(self, tracks, pstate):
        self.write_iTunesSD(tracks)
        self.write_iTunesStats(tracks)
        self.write_iTunesPState(pstate)
#####################################################################
if __name__ == '__main__':
    # Usage: shuffle.py SRC_DIR [DEST_DIR]
    # Reads and prints the database in SRC_DIR, exercises the Track cache,
    # and, if DEST_DIR is given, writes the database there.
    if len(sys.argv) > 1:
        # try reading
        start_dir = os.getcwd()
        os.chdir(sys.argv[1])
        db = ShuffleDB()
        tracks, pstate = db.read_all()
        for track in tracks:
            print(track)
        print(pstate)

        # try cache
        Track.set_old_tracks(tracks)
        print(Track.new("foo.mp3"))
        if tracks:  # an empty database has no first track to look up
            print(Track.new(tracks[0].filename))

        if len(sys.argv) > 2:
            # try writing; DEST_DIR is relative to the original working dir
            os.chdir(start_dir)
            os.chdir(sys.argv[2])
            db.write_all(tracks, pstate)