Lua: Don't lua_error() out of context with pending dtors
[lsnes.git] / src / library / fileimage-patch-bps.cpp
blob02e3990701855e0daad036eb316e1e3ed21ba1af
1 #include "minmax.hpp"
2 #include "fileimage-patch.hpp"
3 #include "serialization.hpp"
4 #include "string.hpp"
5 #include <cstdint>
6 #include <limits>
7 #include <cstring>
8 #include <iostream>
9 #include <zlib.h>
11 namespace fileimage
13 namespace
15 uint8_t readbyte(const char* buf, uint64_t& pos, uint64_t size)
17 if(pos >= size)
18 (stringfmt() << "Attempted to read byte past the end of patch (" << pos << " >= "
19 << size << ").").throwex();
20 return static_cast<uint8_t>(buf[pos++]);
23 uint64_t safe_add(uint64_t a, uint64_t b)
25 if(a + b < a)
26 (stringfmt() << "Integer overflow (" << a << " + " << b << ") processing patch.").throwex();
27 return a + b;
30 uint64_t safe_sub(uint64_t a, uint64_t b)
32 if(a < b)
33 (stringfmt() << "Integer underflow (" << a << " - " << b << ") processing patch.").throwex();
34 return a - b;
37 uint64_t decode_varint(const char* buf, uint64_t& pos, uint64_t size)
39 uint64_t v = 0;
40 size_t i;
41 uint64_t y;
42 for(i = 0; i < 10; i++) {
43 y = readbyte(buf, pos, size) ^ 0x80;
44 v += (y << (7 * i));
45 if(i == 8 && (y | ((v >> 63) ^ 1)) == 255)
46 (stringfmt() << "Varint decoding overlows: v=" << v << " y=" << y << ".").throwex();
47 if(i == 9 && y > 0)
48 (stringfmt() << "Varint decoding overlows: v=" << v << " y=" << y << ".").throwex();
49 if(y < 128)
50 return v;
52 (stringfmt() << "Varint decoding overlows: v=" << v << " y=" << y << ".").throwex();
53 return 0;
/*
 * Patch handler for the BPS format (files starting with the magic "BPS1").
 * `bpspatch` is a file-scope instance of it. NOTE(review): presumably the `patcher` base
 * class makes instances discoverable by the patch loader on construction — confirm in
 * fileimage-patch.hpp.
 */
struct bps_patcher : public patcher
{
	~bps_patcher() throw();
	// True if the buffer is longer than 4 bytes and starts with "BPS1".
	bool identify(const std::vector<char>& patch) throw();
	// Applies `patch` to `original`, writing the result to `out`. `offset` must be 0:
	// BPS patches do not support offsets.
	void dopatch(std::vector<char>& out, const std::vector<char>& original,
		const std::vector<char>& patch, int32_t offset) throw(std::bad_alloc, std::runtime_error);
} bpspatch;
// Empty destructor: bps_patcher declares no data members of its own to release.
bps_patcher::~bps_patcher() throw()
{
}
68 bool bps_patcher::identify(const std::vector<char>& patch) throw()
70 return (patch.size() > 4 && patch[0] == 'B' && patch[1] == 'P' && patch[2] == 'S' && patch[3] == '1');
73 void bps_patcher::dopatch(std::vector<char>& out, const std::vector<char>& original,
74 const std::vector<char>& patch, int32_t offset) throw(std::bad_alloc, std::runtime_error)
76 if(offset)
77 (stringfmt() << "Nonzero offsets (" << offset << ") not allowed in BPS mode.").throwex();
78 if(patch.size() < 19)
79 (stringfmt() << "Patch is too masll to be valid BPS patch (" << patch.size()
80 << " < 19).").throwex();
81 uint64_t ioffset = 4;
82 const char* _patch = &patch[0];
83 size_t psize = patch.size() - 12;
84 uint32_t crc_init = crc32(0, NULL, 0);
85 uint32_t pchcrc_c = crc32(crc_init, reinterpret_cast<const uint8_t*>(&patch[0]), patch.size() - 4);
86 uint32_t pchcrc = serialization::u32l(_patch + psize + 8);
87 if(pchcrc_c != pchcrc)
88 (stringfmt() << "CRC mismatch on patch: Claimed: " << pchcrc << " Actual: " << pchcrc_c
89 << ".").throwex();
90 uint32_t srccrc = serialization::u32l(_patch + psize + 0);
91 uint32_t dstcrc = serialization::u32l(_patch + psize + 4);
92 uint64_t srcsize = decode_varint(_patch, ioffset, psize);
93 uint64_t dstsize = decode_varint(_patch, ioffset, psize);
94 uint64_t mdtsize = decode_varint(_patch, ioffset, psize);
95 ioffset += mdtsize;
96 if(ioffset < mdtsize || ioffset > psize)
97 (stringfmt() << "Metadata size invalid: " << mdtsize << "@" << ioffset << ", plimit="
98 << patch.size() << ".").throwex();
100 if(srcsize != original.size())
101 (stringfmt() << "Size mismatch on original: Claimed: " << srcsize << " Actual: "
102 << original.size() << ".").throwex();
103 uint32_t srccrc_c = crc32(crc_init, reinterpret_cast<const uint8_t*>(&original[0]), original.size());
104 if(srccrc_c != srccrc)
105 (stringfmt() << "CRC mismatch on original: Claimed: " << srccrc << " Actual: " << srccrc_c
106 << ".").throwex();
108 out.resize(dstsize);
109 uint64_t target_ptr = 0;
110 uint64_t source_rptr = 0;
111 uint64_t target_rptr = 0;
112 while(ioffset < psize) {
113 uint64_t opc = decode_varint(_patch, ioffset, psize);
114 uint64_t len = (opc >> 2) + 1;
115 uint64_t off = (opc & 2) ? decode_varint(_patch, ioffset, psize) : 0;
116 bool negative = ((off & 1) != 0);
117 off >>= 1;
118 if(safe_add(target_ptr, len) > dstsize)
119 (stringfmt() << "Illegal write: " << len << "@" << target_ptr << ", wlimit="
120 << dstsize << ".").throwex();
121 const char* src;
122 size_t srcoffset;
123 size_t srclimit;
124 const char* msg;
125 switch(opc & 3) {
126 case 0:
127 src = &original[0];
128 srcoffset = target_ptr;
129 srclimit = srcsize;
130 msg = "source";
131 break;
132 case 1:
133 src = &patch[0];
134 srcoffset = ioffset;
135 srclimit = psize - 12;
136 ioffset += len;
137 msg = "patch";
138 break;
139 case 2:
140 if(negative)
141 source_rptr = safe_sub(source_rptr, off);
142 else
143 source_rptr = safe_add(source_rptr, off);
144 src = &original[0];
145 srcoffset = source_rptr;
146 srclimit = srcsize;
147 source_rptr += len;
148 msg = "source";
149 break;
150 case 3:
151 if(negative)
152 target_rptr = safe_sub(target_rptr, off);
153 else
154 target_rptr = safe_add(target_rptr, off);
155 src = &out[0];
156 srcoffset = target_rptr;
157 srclimit = min(dstsize, target_rptr + len);
158 target_rptr += len;
159 msg = "target";
160 break;
162 if(safe_add(srcoffset, len) > srclimit)
163 (stringfmt() << "Illegal read: " << len << "@" << srcoffset << " from " << msg
164 << ", limit=" << srclimit << ".").throwex();
165 for(uint64_t i = 0; i < len; i++)
166 out[target_ptr + i] = src[srcoffset + i];
167 target_ptr += len;
169 if(target_ptr != out.size())
170 (stringfmt() << "Size mismatch on result: Claimed: " << out.size() << " Actual: "
171 << target_ptr << ".").throwex();
172 uint32_t dstcrc_c = crc32(crc_init, reinterpret_cast<const uint8_t*>(&out[0]), out.size());
173 if(dstcrc_c != dstcrc)
174 (stringfmt() << "CRC mismatch on result: Claimed: " << dstcrc << " Actual: " << dstcrc_c
175 << ".").throwex();