lib/compression: Fix length check
[Samba.git] / python / samba / tests / krb5 / xpress.py
blob b0fbe26fafb0d64bfe2ef41d36c39ea948250af9
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Catalyst.Net Ltd 2022
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
14 # You should have received a copy of the GNU General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 from samba.dcerpc import claims
def decompress(data, compression_type, uncompressed_size):
    """Decompress a claims blob according to its declared compression format.

    :param data: the (possibly) compressed bytes.
    :param compression_type: a ``claims.CLAIMS_COMPRESSION_FORMAT_*`` value.
    :param uncompressed_size: the expected decompressed length; only consulted
        for the XPRESS_HUFF format.
    :return: ``data`` unchanged for FORMAT_NONE, otherwise the decompressed
        bytes.
    :raises AssertionError: if ``compression_type`` is not supported.
    """
    if compression_type == claims.CLAIMS_COMPRESSION_FORMAT_NONE:
        # Nothing to undo: the blob is stored as-is.
        return data

    if compression_type == claims.CLAIMS_COMPRESSION_FORMAT_XPRESS_HUFF:
        return lz77_huffman_decompress(data, uncompressed_size)

    raise AssertionError(f'compression type {compression_type} '
                         f'not supported')
def _read_u8(data, pos):
    """Return the byte at ``data[pos]``.

    :raises AssertionError: if ``pos`` is past the end of ``data``.
    """
    if pos >= len(data):
        raise AssertionError(f'compressed data truncated (need 1 byte at '
                             f'offset {pos}, have {len(data)} bytes)')
    return data[pos]


def _read_u16_le(data, pos):
    """Return the little-endian 16-bit value stored at ``data[pos:pos + 2]``.

    :raises AssertionError: if fewer than two bytes remain at ``pos``.
    """
    if pos + 2 > len(data):
        raise AssertionError(f'compressed data truncated (need 2 bytes at '
                             f'offset {pos}, have {len(data)} bytes)')
    return data[pos] | (data[pos + 1] << 8)


def _build_decoding_table(symbol_bit_lengths):
    """Build the 2**15-entry lookup table used to decode Huffman symbols.

    Symbols are laid out in order of increasing code length (1..15), and in
    increasing symbol order within one length. A symbol whose code is
    ``bit_len`` bits long fills ``2 ** (15 - bit_len)`` consecutive entries,
    so indexing the table with the next 15 input bits yields the symbol whose
    code is a prefix of those bits. A length of 0 means the symbol is unused.

    :param symbol_bit_lengths: code length of each symbol, indexed by symbol.
    :return: a list of exactly 2**15 symbols.
    :raises AssertionError: if the lengths do not fill exactly 2**15 entries.
    """
    decoding_table = []
    for bit_len in range(1, 16):
        count = 1 << (15 - bit_len)
        for symbol, encoded_bit_length in enumerate(symbol_bit_lengths):
            if encoded_bit_length == bit_len:
                decoding_table.extend([symbol] * count)

    if len(decoding_table) != 2 ** 15:
        raise AssertionError(f'Error constructing decoding table (len = '
                             f'{len(decoding_table)})')

    return decoding_table


def lz77_huffman_decompress(data, decompressed_size):
    """Decompress an LZ77+Huffman stream, following the [MS-XCA]
    decompression pseudocode.

    Layout consumed here:

    * bytes 0..255: the code length (0..15) of each of the 512 symbols, two
      per byte -- symbol ``2*i`` in the low nibble of byte ``i`` and symbol
      ``2*i + 1`` in the high nibble;
    * then a bit stream read as 16-bit little-endian words and consumed from
      the most significant bit down. Extra match-length bytes are read from
      the byte position following the words already loaded.

    Decoded symbols below 256 are literal bytes. Symbol 256 ends the stream
    when the whole input has been consumed and exactly ``decompressed_size``
    bytes have been produced; otherwise it, like every symbol above 256,
    encodes a back-reference. For ``s = symbol - 256``, the low nibble of
    ``s`` gives the match length (15 escapes to a following byte, and 255 in
    that byte escapes to a following 16-bit word). The high nibble is the
    number of offset bits that follow in the bit stream. Matches are at least
    3 bytes long and may overlap the bytes they produce.

    Only a single block (at most 65536 bytes of output) is handled.

    :param data: the compressed bytes.
    :param decompressed_size: the exact number of bytes expected.
    :return: the decompressed data as ``bytes``.
    :raises AssertionError: on malformed, truncated or over-long input, or
        if the stream does not end within the first 65536-byte block.
    """
    # 256 bytes of code lengths plus the 4 bytes that prime the bit window.
    if len(data) < 256 + 4:
        raise AssertionError(f'compressed data too short '
                             f'({len(data)} bytes)')

    symbol_bit_lengths = []
    for pair in data[:256]:
        symbol_bit_lengths.append(pair & 0xf)
        symbol_bit_lengths.append(pair >> 4)

    decoding_table = _build_decoding_table(symbol_bit_lengths)

    # Start at the end of the Huffman table and load a 32-bit window.
    current_pos = 256
    next_bits = _read_u16_le(data, current_pos) << 16
    current_pos += 2
    next_bits |= _read_u16_le(data, current_pos)
    current_pos += 2
    # Valid bits in the window beyond 16; negative means a refill is due.
    extra_bit_count = 16

    def consume_bits(bit_count):
        # Shift bit_count bits out of the top of the 32-bit window and top
        # it up with the next 16-bit word once fewer than 16 valid bits
        # remain.
        nonlocal next_bits, extra_bit_count, current_pos
        next_bits = (next_bits << bit_count) & 0xffffffff
        extra_bit_count -= bit_count
        if extra_bit_count < 0:
            next_bits |= _read_u16_le(data, current_pos) << -extra_bit_count
            extra_bit_count += 16
            current_pos += 2

    output = bytearray()
    block_end = 65536

    while len(output) < block_end:
        # The decoding table is indexed by the top 15 bits of the window.
        huffman_symbol = decoding_table[next_bits >> (32 - 15)]
        consume_bits(symbol_bit_lengths[huffman_symbol])

        if huffman_symbol < 256:
            output.append(huffman_symbol)
        elif (huffman_symbol == 256 and current_pos == len(data)
              and len(output) == decompressed_size):
            return bytes(output)
        else:
            # A match. Symbol 256 also lands here when the end-of-stream
            # conditions above are not all met.
            huffman_symbol -= 256
            match_len = huffman_symbol & 0xf
            match_offset_bit_len = huffman_symbol >> 4

            if match_len == 15:
                match_len = _read_u8(data, current_pos)
                current_pos += 1

                if match_len == 255:
                    match_len = _read_u16_le(data, current_pos)
                    current_pos += 2

                    if match_len < 15:
                        raise AssertionError(f'match_len is too small! '
                                             f'({match_len} < 15)')
                    match_len -= 15
                match_len += 15
            match_len += 3

            # The offset's leading 1 bit is implicit; only the bits below it
            # are stored. With zero offset bits the shift by 32 yields 0,
            # because the window is kept to 32 bits.
            match_offset = next_bits >> (32 - match_offset_bit_len)
            match_offset |= 1 << match_offset_bit_len
            consume_bits(match_offset_bit_len)

            # Without this check a negative start index would silently wrap
            # around to the end of the output.
            if match_offset > len(output):
                raise AssertionError(f'match offset {match_offset} points '
                                     f'before the start of the output '
                                     f'({len(output)} bytes)')

            # Copy byte by byte: the source may overlap the bytes being
            # appended (e.g. offset 1 repeats the last byte).
            start = len(output) - match_offset
            for i in range(start, start + match_len):
                output.append(output[i])

        # Output never shrinks, so once it is too long the end-of-stream
        # condition (exact size) can no longer be met.
        if len(output) > decompressed_size:
            raise AssertionError(f'decompressed data exceeds the expected '
                                 f'size ({len(output)} > '
                                 f'{decompressed_size})')

    # Any further block would begin with its own Huffman table; decoding
    # such streams is not implemented.
    raise AssertionError('LZ77+Huffman stream did not end within the first '
                         '65536-byte block; multi-block streams are not '
                         'supported')