Fix integer overflow in ft_rendered_size_line
[ilaris-y4m-tools.git] / fftaa.cpp
blob0e5ffd15fc3e7596d38f7f05d69263b46f1420d6
1 #include <limits>
2 #include <cstdint>
3 #include <complex>
4 #include <vector>
5 #include <fftw3.h>
6 #include <cstring>
7 #include <stdexcept>
8 #include <iostream>
9 #include "parseval.hpp"
10 #include "yuv4mpeg.hpp"
13 fftw_plan forward_plan;
14 fftw_plan backward_plan;
15 fftw_plan forward_plan2;
16 fftw_plan backward_plan2;
17 std::complex<double>* input_memory;
18 std::complex<double>* output_memory;
20 void allocate_ffts(size_t iw, size_t ih, size_t ow, size_t oh, size_t wf, size_t hf)
22 input_memory = reinterpret_cast<std::complex<double>*>(fftw_alloc_complex(iw * ih));
23 output_memory = reinterpret_cast<std::complex<double>*>(fftw_alloc_complex(iw * ih));
24 forward_plan = fftw_plan_dft_2d(ih, iw, reinterpret_cast<fftw_complex*>(input_memory),
25 reinterpret_cast<fftw_complex*>(output_memory), -1, FFTW_MEASURE);
26 backward_plan = fftw_plan_dft_2d(oh, ow, reinterpret_cast<fftw_complex*>(input_memory),
27 reinterpret_cast<fftw_complex*>(output_memory), +1, FFTW_MEASURE);
28 forward_plan2 = fftw_plan_dft_2d(ih / wf, iw / hf, reinterpret_cast<fftw_complex*>(input_memory),
29 reinterpret_cast<fftw_complex*>(output_memory), -1, FFTW_MEASURE);
30 backward_plan2 = fftw_plan_dft_2d(oh / wf, ow / hf, reinterpret_cast<fftw_complex*>(input_memory),
31 reinterpret_cast<fftw_complex*>(output_memory), +1, FFTW_MEASURE);
// Copy the low-frequency coefficients of one FFT row of length isize into a
// row of length osize (osize <= isize), discarding the highest frequencies.
//
// In FFTW's output order the start of a row holds DC and the positive
// frequencies and the end holds the negative ones. So:
//   - the first (osize + 1) / 2 entries are taken from the start of source,
//   - the last (osize - 1) / 2 entries are taken from the end of source,
//   - for even osize the single remaining middle entry (Nyquist bin) is zeroed.
// target and source must not overlap.
void copyrow(std::complex<double>* target, const std::complex<double>* source, size_t isize, size_t osize)
{
	// Nothing to copy; this also keeps (osize - 1) below from wrapping around.
	if(osize == 0)
		return;
	size_t lsize = (osize + 1) >> 1;
	size_t rsize = (osize - 1) >> 1;
	std::memcpy(target, source, lsize * sizeof(std::complex<double>));
	std::memcpy(target + osize - rsize, source + isize - rsize, rsize * sizeof(std::complex<double>));
	if((osize & 1) == 0)
		target[osize / 2] = 0;
}
44 void copyimage(std::complex<double>* target, std::complex<double>* source, size_t isize, size_t osize,
45 size_t irows, size_t orows)
47 size_t usize = (orows + 1) >> 1;
48 size_t lsize = (orows - 1) >> 1;
50 for(size_t i = 0; i < usize; i++)
51 copyrow(target + i * osize, source + i * isize, isize, osize);
52 for(size_t i = 0; i < lsize; i++)
53 copyrow(target + (orows - lsize + i) * osize, source + (irows - lsize + i) * isize, isize, osize);
54 if((orows & 1) == 0)
55 for(size_t i = 0; i < osize; i++)
56 target[(orows / 2) * osize + i] = 0;
59 template<bool small> void fft_aa_plane(double* obuffer, const double* ibuffer, size_t width, size_t height,
60 size_t istride, size_t ostride, size_t owidth, size_t oheight)
62 double oscale = static_cast<double>(width) * height;
63 for(size_t i = 0; i < height; i++)
64 for(size_t j = 0; j < width; j++)
65 input_memory[i * width + j] = ibuffer[i * istride + j];
66 if(small)
67 fftw_execute(forward_plan2);
68 else
69 fftw_execute(forward_plan);
70 copyimage(input_memory, output_memory, width, owidth, height, oheight);
71 if(small)
72 fftw_execute(backward_plan2);
73 else
74 fftw_execute(backward_plan);
76 for(size_t i = 0; i < oheight; i++)
77 for(size_t j = 0; j < owidth; j++)
78 obuffer[i * ostride + j] = std::abs(output_memory[i * owidth + j]) / oscale;
// Convert an unsigned integer sample to double, biased by 2^bits where
// bits = 8 * sizeof(T): the result is exactly 2^bits + val. Callers subtract
// the bias afterwards.
//
// The IEEE-754 bit pattern is built directly. The exponent field encodes 2^bits
// and val occupies the top `bits` bits of the 52-bit mantissa. The pattern is
// moved into a double with memcpy, which is well defined, unlike dereferencing
// a reinterpret_cast pointer, which breaks strict aliasing.
template<typename T> inline double integer_to_float(T val)
{
	static_assert(std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed && sizeof(T) <= 4,
		"integer_to_float requires an unsigned integer type of at most 32 bits");
	const unsigned bits = 8 * sizeof(T);
	const uint64_t v = ((0x3FFULL + bits) << 52) + (static_cast<uint64_t>(val) << (52 - bits));
	double result;
	std::memcpy(&result, &v, sizeof result);
	return result;
}
// Inverse of integer_to_float(): convert a double biased by 2^bits
// (bits = 8 * sizeof(T)) back to an unsigned integer, truncating toward zero
// and saturating.
//   - val < 2^bits (including any negative value) -> 0
//   - val >= 2^(bits+1)                          -> std::numeric_limits<T>::max()
//   - otherwise                                  -> floor(val - 2^bits)
// This works on the IEEE-754 bit pattern, which is read with memcpy so that
// strict aliasing is not violated.
template<typename T> inline T float_to_integer(double val)
{
	static_assert(std::numeric_limits<T>::is_integer && !std::numeric_limits<T>::is_signed && sizeof(T) <= 4,
		"float_to_integer requires an unsigned integer type of at most 32 bits");
	const unsigned bits = 8 * sizeof(T);
	const uint64_t low = (0x3FFULL + bits) << 52;   // bit pattern of 2^bits
	const uint64_t high = (0x400ULL + bits) << 52;  // bit pattern of 2^(bits+1)
	uint64_t v;
	std::memcpy(&v, &val, sizeof v);
	// Mark branches for underflow/overflow as unlikely.
	// A set sign bit would otherwise compare as huge and wrongly saturate to max.
	if(__builtin_expect((v >> 63) != 0 || v < low, 0))
		return 0;
	if(__builtin_expect(v >= high, 0))
		return std::numeric_limits<T>::max();
	// In [2^bits, 2^(bits+1)) the top `bits` mantissa bits are floor(val - 2^bits);
	// the exponent bits above them are discarded by the conversion to T.
	return static_cast<T>(v >> (52 - bits));
}
102 template<typename T, unsigned bpp, unsigned shift, bool small> void fft_aa_plane(T* obuffer, const T* ibuffer,
103 size_t width, size_t height, size_t istride, size_t ostride, size_t owidth, size_t oheight)
105 static std::vector<double> in;
106 static std::vector<double> out;
107 if(in.size() != height * istride) {
108 in.resize(height * istride);
109 out.resize(oheight * ostride);
111 for(size_t i = 0; i < height * istride; i++)
112 in[i] = integer_to_float<T>(ibuffer[bpp * i + shift]);
113 for(size_t i = 0; i < height * istride; i++)
114 in[i] -= (1ULL << (8 * sizeof(T)));
115 fft_aa_plane<small>(&out[0], &in[0], width, height, istride, ostride, owidth, oheight);
116 for(size_t i = 0; i < oheight * ostride; i++)
117 out[i] += (1ULL << (8 * sizeof(T)));
118 for(size_t i = 0; i < oheight * ostride; i++)
119 obuffer[bpp * i + shift] = float_to_integer<T>(out[i]);
122 inline void fft_aa_rgb(uint8_t* obuffer, const uint8_t* ibuffer, size_t width, size_t height, size_t owidth,
123 size_t oheight)
125 fft_aa_plane<uint8_t, 3, 0, false>(obuffer, ibuffer, width, height, width, owidth, owidth, oheight);
126 fft_aa_plane<uint8_t, 3, 1, false>(obuffer, ibuffer, width, height, width, owidth, owidth, oheight);
127 fft_aa_plane<uint8_t, 3, 2, false>(obuffer, ibuffer, width, height, width, owidth, owidth, oheight);
130 template<typename T, size_t w, size_t h> inline void fft_aa_yuv(T* obuffer, const T* ibuffer, size_t width,
131 size_t height, size_t owidth, size_t oheight)
133 size_t ilsize = width * height;
134 size_t icsize = width * height / w / h;
135 size_t olsize = owidth * oheight;
136 size_t ocsize = owidth * oheight / w / h;
137 size_t cwidth = width / w;
138 size_t cheight = height / h;
139 size_t ocwidth = owidth / w;
140 size_t ocheight = oheight / h;
141 if(width % w || height % h || owidth % w || oheight % h)
142 throw std::runtime_error("Size must be multiple of chroma block size");
144 fft_aa_plane<T, 1, 0, false>(obuffer, ibuffer, width, height, width, owidth, owidth, oheight);
145 fft_aa_plane<T, 1, 0, true>(obuffer + olsize, ibuffer + ilsize, cwidth, cheight, cwidth, ocwidth, ocwidth,
146 ocheight);
147 fft_aa_plane<T, 1, 0, true>(obuffer + olsize + ocsize, ibuffer + ilsize + icsize, cwidth, cheight, cwidth,
148 ocwidth, ocwidth, ocheight);
// Common signature of every fft_aa2<> instantiation, letting main() pick one at
// runtime from the stream's chroma format. Buffers are untyped raw frame data.
using fft_scale_fun_t = void (*)(void* obuffer, const void* ibuffer, size_t width, size_t height,
	size_t owidth, size_t oheight);
154 template<typename T, size_t w, size_t h, bool rgb> inline void fft_aa2(void* obuffer, const void* ibuffer,
155 size_t width, size_t height, size_t owidth, size_t oheight)
157 if(rgb)
158 fft_aa_rgb(reinterpret_cast<uint8_t*>(obuffer), reinterpret_cast<const uint8_t*>(ibuffer),
159 width, height, owidth, oheight);
160 else
161 fft_aa_yuv<T, w, h>(reinterpret_cast<T*>(obuffer), reinterpret_cast<const T*>(ibuffer), width, height,
162 owidth, oheight);
165 int main(int argc, char** argv)
167 size_t outwidth = 0;
168 size_t outheight = 0;
169 size_t hscale = 0;
170 size_t vscale = 0;
171 for(int i = 1; i < argc; i++) {
172 std::string arg = argv[i];
173 regex_results r;
174 if(r = regex("--resolution=([1-9][0-9]*)x([1-9][0-9]*)", arg)) {
175 try {
176 outwidth = parse_value<unsigned>(r[1]);
177 outheight = parse_value<unsigned>(r[2]);
178 } catch(std::exception& e) {
179 std::cerr << "fftaa: Bad resolution '" << r[1] << "x" << r[2] << "'" << std::endl;
180 return 1;
182 } else if(r = regex("--scale=([1-9][0-9]*)x([1-9][0-9]*)", arg)) {
183 try {
184 hscale = parse_value<unsigned>(r[1]);
185 vscale = parse_value<unsigned>(r[2]);
186 } catch(std::exception& e) {
187 std::cerr << "fftaa: Bad scale '" << r[1] << "x" << r[2] << "'" << std::endl;
188 return 1;
190 } else if(r = regex("--scale=([1-9][0-9]*)", arg)) {
191 try {
192 hscale = parse_value<unsigned>(r[1]);
193 vscale = parse_value<unsigned>(r[1]);
194 } catch(std::exception& e) {
195 std::cerr << "fftaa: Bad scale '" << r[1] << "'" << std::endl;
196 return 1;
198 } else {
199 std::cerr << "fftaa: Unrecognized option '" << arg << "'" << std::endl;
200 return 1;
203 if((!outwidth || !outheight) && (!hscale || !vscale)) {
204 std::cerr << "fftaa: Resolution needed" << std::endl;
205 return 1;
208 //Open files.
209 FILE* in = stdin;
210 FILE* out = stdout;
211 mark_pipe_as_binary(in);
212 mark_pipe_as_binary(out);
214 //Fixup header.
215 try {
216 struct yuv4mpeg_stream_header strmh(in);
217 size_t inwidth = strmh.width;
218 size_t inheight = strmh.height;
220 if(hscale && vscale) {
221 outwidth = inwidth / hscale;
222 outheight = inheight / vscale;
225 strmh.width = outwidth;
226 strmh.height = outheight;
228 if(inwidth < outwidth || inheight < outheight) {
229 std::cerr << "fftaa: Only scaling to smaller is supported" << std::endl;
230 return 1;
233 fft_scale_fun_t fun;
234 size_t quads;
235 size_t wf = 1;
236 size_t hf = 1;
238 if(strmh.chroma == "rgb") {
239 quads = 12;
240 fun = fft_aa2<uint8_t, 1, 1, true>;
241 } else if(strmh.chroma == "420") {
242 quads = 6;
243 fun = fft_aa2<uint8_t, 2, 2, false>;
244 hf = wf = 2;
245 } else if(strmh.chroma == "420p16") {
246 quads = 12;
247 fun = fft_aa2<uint16_t, 2, 2, false>;
248 hf = wf = 2;
249 } else if(strmh.chroma == "422") {
250 quads = 8;
251 fun = fft_aa2<uint8_t, 2, 1, false>;
252 wf = 2;
253 } else if(strmh.chroma == "422p16") {
254 quads = 16;
255 fun = fft_aa2<uint16_t, 2, 1, false>;
256 wf = 2;
257 } else if(strmh.chroma == "444") {
258 quads = 12;
259 fun = fft_aa2<uint8_t, 1, 1, false>;
260 } else if(strmh.chroma == "444p16") {
261 quads = 24;
262 fun = fft_aa2<uint16_t, 1, 1, false>;
263 } else
264 throw std::runtime_error("Unsupported input chroma type '" + strmh.chroma + "'");
266 std::string _strmh = std::string(strmh);
267 write_or_die(out, _strmh);
269 allocate_ffts(inwidth, inheight, outwidth, outheight, wf, hf);
271 std::vector<char> buffer;
272 std::vector<char> buffer2;
273 size_t insize = inwidth * inheight * quads / 4;
274 size_t outsize = outwidth * outheight * quads / 4;
275 buffer.resize(insize + 128);
276 buffer2.resize(outsize + 128);
277 while(true) {
278 std::string _framh;
279 if(!read_line2(in, _framh))
280 break;
281 read_or_die(in, &buffer[0], insize);
282 fun(&buffer2[0], &buffer[0], inwidth, inheight, outwidth, outheight);
283 write_or_die(out, _framh);
284 write_or_die(out, &buffer2[0], outsize);
287 } catch(std::exception& e) {
288 std::cerr << "fftaa: Error: " << e.what() << std::endl;
289 return 1;
291 return 0;