Clean up library includes
[lsnes.git] / src / library / patch-bps.cpp
blob279450a625d58171795dd34987fb3258a56f65f7
1 #include "minmax.hpp"
2 #include "patch.hpp"
3 #include "serialization.hpp"
4 #include "string.hpp"
5 #include <cstdint>
6 #include <limits>
7 #include <cstring>
8 #include <iostream>
9 #include <zlib.h>
11 namespace
13 uint8_t readbyte(const char* buf, uint64_t& pos, uint64_t size)
15 if(pos >= size)
16 (stringfmt() << "Attempted to read byte past the end of patch (" << pos << " >= "
17 << size << ").").throwex();
18 return static_cast<uint8_t>(buf[pos++]);
21 uint64_t safe_add(uint64_t a, uint64_t b)
23 if(a + b < a)
24 (stringfmt() << "Integer overflow (" << a << " + " << b << ") processing patch.").throwex();
25 return a + b;
28 uint64_t safe_sub(uint64_t a, uint64_t b)
30 if(a < b)
31 (stringfmt() << "Integer underflow (" << a << " - " << b << ") processing patch.").throwex();
32 return a - b;
35 uint64_t decode_varint(const char* buf, uint64_t& pos, uint64_t size)
37 uint64_t v = 0;
38 size_t i;
39 uint64_t y;
40 for(i = 0; i < 10; i++) {
41 y = readbyte(buf, pos, size) ^ 0x80;
42 v += (y << (7 * i));
43 if(i == 8 && (y | ((v >> 63) ^ 1)) == 255)
44 (stringfmt() << "Varint decoding overlows: v=" << v << " y=" << y << ".").throwex();
45 if(i == 9 && y > 0)
46 (stringfmt() << "Varint decoding overlows: v=" << v << " y=" << y << ".").throwex();
47 if(y < 128)
48 return v;
52 struct bps_patcher : public rom_patcher
54 ~bps_patcher() throw();
55 bool identify(const std::vector<char>& patch) throw();
56 void dopatch(std::vector<char>& out, const std::vector<char>& original,
57 const std::vector<char>& patch, int32_t offset) throw(std::bad_alloc, std::runtime_error);
58 } bpspatch;
60 bps_patcher::~bps_patcher() throw()
64 bool bps_patcher::identify(const std::vector<char>& patch) throw()
66 return (patch.size() > 4 && patch[0] == 'B' && patch[1] == 'P' && patch[2] == 'S' && patch[3] == '1');
69 void bps_patcher::dopatch(std::vector<char>& out, const std::vector<char>& original,
70 const std::vector<char>& patch, int32_t offset) throw(std::bad_alloc, std::runtime_error)
72 if(offset)
73 (stringfmt() << "Nonzero offsets (" << offset << ") not allowed in BPS mode.").throwex();
74 if(patch.size() < 19)
75 (stringfmt() << "Patch is too masll to be valid BPS patch (" << patch.size()
76 << " < 19).").throwex();
77 uint64_t ioffset = 4;
78 const char* _patch = &patch[0];
79 size_t psize = patch.size() - 12;
80 uint32_t crc_init = crc32(0, NULL, 0);
81 uint32_t pchcrc_c = crc32(crc_init, reinterpret_cast<const uint8_t*>(&patch[0]), patch.size() - 4);
82 uint32_t pchcrc = read32ule(_patch + psize + 8);
83 if(pchcrc_c != pchcrc)
84 (stringfmt() << "CRC mismatch on patch: Claimed: " << pchcrc << " Actual: " << pchcrc_c
85 << ".").throwex();
86 uint32_t srccrc = read32ule(_patch + psize + 0);
87 uint32_t dstcrc = read32ule(_patch + psize + 4);
88 uint64_t srcsize = decode_varint(_patch, ioffset, psize);
89 uint64_t dstsize = decode_varint(_patch, ioffset, psize);
90 uint64_t mdtsize = decode_varint(_patch, ioffset, psize);
91 ioffset += mdtsize;
92 if(ioffset < mdtsize || ioffset > psize)
93 (stringfmt() << "Metadata size invalid: " << mdtsize << "@" << ioffset << ", plimit="
94 << patch.size() << ".").throwex();
96 if(srcsize != original.size())
97 (stringfmt() << "Size mismatch on original: Claimed: " << srcsize << " Actual: "
98 << original.size() << ".").throwex();
99 uint32_t srccrc_c = crc32(crc_init, reinterpret_cast<const uint8_t*>(&original[0]), original.size());
100 if(srccrc_c != srccrc)
101 (stringfmt() << "CRC mismatch on original: Claimed: " << srccrc << " Actual: " << srccrc_c
102 << ".").throwex();
104 out.resize(dstsize);
105 uint64_t target_ptr = 0;
106 uint64_t source_rptr = 0;
107 uint64_t target_rptr = 0;
108 while(ioffset < psize) {
109 uint64_t opc = decode_varint(_patch, ioffset, psize);
110 uint64_t len = (opc >> 2) + 1;
111 uint64_t off = (opc & 2) ? decode_varint(_patch, ioffset, psize) : 0;
112 bool negative = ((off & 1) != 0);
113 off >>= 1;
114 if(safe_add(target_ptr, len) > dstsize)
115 (stringfmt() << "Illegal write: " << len << "@" << target_ptr << ", wlimit="
116 << dstsize << ".").throwex();
117 const char* src;
118 size_t srcoffset;
119 size_t srclimit;
120 const char* msg;
121 switch(opc & 3) {
122 case 0:
123 src = &original[0];
124 srcoffset = target_ptr;
125 srclimit = srcsize;
126 msg = "source";
127 break;
128 case 1:
129 src = &patch[0];
130 srcoffset = ioffset;
131 srclimit = psize - 12;
132 ioffset += len;
133 msg = "patch";
134 break;
135 case 2:
136 if(negative)
137 source_rptr = safe_sub(source_rptr, off);
138 else
139 source_rptr = safe_add(source_rptr, off);
140 src = &original[0];
141 srcoffset = source_rptr;
142 srclimit = srcsize;
143 source_rptr += len;
144 msg = "source";
145 break;
146 case 3:
147 if(negative)
148 target_rptr = safe_sub(target_rptr, off);
149 else
150 target_rptr = safe_add(target_rptr, off);
151 src = &out[0];
152 srcoffset = target_rptr;
153 srclimit = min(dstsize, target_rptr + len);
154 target_rptr += len;
155 msg = "target";
156 break;
158 if(safe_add(srcoffset, len) > srclimit)
159 (stringfmt() << "Illegal read: " << len << "@" << srcoffset << " from " << msg
160 << ", limit=" << srclimit << ".").throwex();
161 for(uint64_t i = 0; i < len; i++)
162 out[target_ptr + i] = src[srcoffset + i];
163 target_ptr += len;
165 if(target_ptr != out.size())
166 (stringfmt() << "Size mismatch on result: Claimed: " << out.size() << " Actual: "
167 << target_ptr << ".").throwex();
168 uint32_t dstcrc_c = crc32(crc_init, reinterpret_cast<const uint8_t*>(&out[0]), out.size());
169 if(dstcrc_c != dstcrc)
170 (stringfmt() << "CRC mismatch on result: Claimed: " << dstcrc << " Actual: " << dstcrc_c
171 << ".").throwex();