Add dr prediction test
[aom.git] / test / fft_test.cc
blob 6a3fd0d04fe0d0c116d5e87ced85ac9a51107556
1 #include <math.h>
3 #include <algorithm>
4 #include <complex>
5 #include <vector>
7 #include "aom_dsp/fft_common.h"
8 #include "aom_mem/aom_mem.h"
9 #if ARCH_X86 || ARCH_X86_64
10 #include "aom_ports/x86.h"
11 #endif
12 #include "av1/common/common.h"
13 #include "config/aom_dsp_rtcd.h"
14 #include "test/acm_random.h"
15 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
17 namespace {
// Signature shared by the 2D transforms under test, both forward and inverse:
// (input, temporary buffer, output).
typedef void (*tform_fun_t)(const float *in, float *scratch, float *out);
21 // Simple 1D FFT implementation
22 template <typename InputType>
23 void fft(const InputType *data, std::complex<float> *result, int n) {
24 if (n == 1) {
25 result[0] = data[0];
26 return;
28 std::vector<InputType> temp(n);
29 for (int k = 0; k < n / 2; ++k) {
30 temp[k] = data[2 * k];
31 temp[n / 2 + k] = data[2 * k + 1];
33 fft(&temp[0], result, n / 2);
34 fft(&temp[n / 2], result + n / 2, n / 2);
35 for (int k = 0; k < n / 2; ++k) {
36 std::complex<float> w = std::complex<float>((float)cos(2. * PI * k / n),
37 (float)-sin(2. * PI * k / n));
38 std::complex<float> a = result[k];
39 std::complex<float> b = result[n / 2 + k];
40 result[k] = a + w * b;
41 result[n / 2 + k] = a - w * b;
// Transposes, in place, the n x n row-major matrix held in *data by swapping
// every element above the main diagonal with its mirror image below it.
void transpose(std::vector<std::complex<float> > *data, int n) {
  std::vector<std::complex<float> > &m = *data;
  for (int row = 0; row + 1 < n; ++row) {
    for (int col = n - 1; col > row; --col) {
      std::swap(m[row * n + col], m[col * n + row]);
    }
  }
}
53 // Simple 2D FFT implementation
54 template <class InputType>
55 std::vector<std::complex<float> > fft2d(const InputType *input, int n) {
56 std::vector<std::complex<float> > rowfft(n * n);
57 std::vector<std::complex<float> > result(n * n);
58 for (int y = 0; y < n; ++y) {
59 fft(input + y * n, &rowfft[y * n], n);
61 transpose(&rowfft, n);
62 for (int y = 0; y < n; ++y) {
63 fft(&rowfft[y * n], &result[y * n], n);
65 transpose(&result, n);
66 return result;
69 struct FFTTestArg {
70 int n;
71 void (*fft)(const float *input, float *temp, float *output);
72 int flag;
73 FFTTestArg(int n_in, tform_fun_t fft_in, int flag_in)
74 : n(n_in), fft(fft_in), flag(flag_in) {}
77 class FFT2DTest : public ::testing::TestWithParam<FFTTestArg> {
78 protected:
79 void SetUp() {
80 int n = GetParam().n;
81 input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n);
82 temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n);
83 output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n * 2);
84 memset(input_, 0, sizeof(*input_) * n * n);
85 memset(temp_, 0, sizeof(*temp_) * n * n);
86 memset(output_, 0, sizeof(*output_) * n * n * 2);
87 #if ARCH_X86 || ARCH_X86_64
88 disabled_ = GetParam().flag != 0 && !(x86_simd_caps() & GetParam().flag);
89 #else
90 disabled_ = GetParam().flag != 0;
91 #endif
93 void TearDown() {
94 aom_free(input_);
95 aom_free(temp_);
96 aom_free(output_);
98 int disabled_;
99 float *input_;
100 float *temp_;
101 float *output_;
104 TEST_P(FFT2DTest, Correct) {
105 if (disabled_) return;
107 int n = GetParam().n;
108 for (int i = 0; i < n * n; ++i) {
109 input_[i] = 1;
110 std::vector<std::complex<float> > expected = fft2d<float>(&input_[0], n);
111 GetParam().fft(&input_[0], &temp_[0], &output_[0]);
112 for (int y = 0; y < n; ++y) {
113 for (int x = 0; x < (n / 2) + 1; ++x) {
114 EXPECT_NEAR(expected[y * n + x].real(), output_[2 * (y * n + x)], 1e-5);
115 EXPECT_NEAR(expected[y * n + x].imag(), output_[2 * (y * n + x) + 1],
116 1e-5);
119 input_[i] = 0;
123 TEST_P(FFT2DTest, Benchmark) {
124 if (disabled_) return;
126 int n = GetParam().n;
127 float sum = 0;
128 for (int i = 0; i < 1000 * (64 - n); ++i) {
129 input_[i % (n * n)] = 1;
130 GetParam().fft(&input_[0], &temp_[0], &output_[0]);
131 sum += output_[0];
132 input_[i % (n * n)] = 0;
136 INSTANTIATE_TEST_CASE_P(
137 FFT2DTestC, FFT2DTest,
138 ::testing::Values(FFTTestArg(2, aom_fft2x2_float_c, 0),
139 FFTTestArg(4, aom_fft4x4_float_c, 0),
140 FFTTestArg(8, aom_fft8x8_float_c, 0),
141 FFTTestArg(16, aom_fft16x16_float_c, 0),
142 FFTTestArg(32, aom_fft32x32_float_c, 0)));
143 #if ARCH_X86 || ARCH_X86_64
144 INSTANTIATE_TEST_CASE_P(
145 FFT2DTestSSE2, FFT2DTest,
146 ::testing::Values(FFTTestArg(4, aom_fft4x4_float_sse2, HAS_SSE2),
147 FFTTestArg(8, aom_fft8x8_float_sse2, HAS_SSE2),
148 FFTTestArg(16, aom_fft16x16_float_sse2, HAS_SSE2),
149 FFTTestArg(32, aom_fft32x32_float_sse2, HAS_SSE2)));
151 INSTANTIATE_TEST_CASE_P(
152 FFT2DTestAVX2, FFT2DTest,
153 ::testing::Values(FFTTestArg(8, aom_fft8x8_float_avx2, HAS_AVX2),
154 FFTTestArg(16, aom_fft16x16_float_avx2, HAS_AVX2),
155 FFTTestArg(32, aom_fft32x32_float_avx2, HAS_AVX2)));
156 #endif
158 struct IFFTTestArg {
159 int n;
160 tform_fun_t ifft;
161 int flag;
162 IFFTTestArg(int n_in, tform_fun_t ifft_in, int flag_in)
163 : n(n_in), ifft(ifft_in), flag(flag_in) {}
166 class IFFT2DTest : public ::testing::TestWithParam<IFFTTestArg> {
167 protected:
168 void SetUp() {
169 int n = GetParam().n;
170 input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n * 2);
171 temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n * 2);
172 output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n);
173 memset(input_, 0, sizeof(*input_) * n * n * 2);
174 memset(temp_, 0, sizeof(*temp_) * n * n * 2);
175 memset(output_, 0, sizeof(*output_) * n * n);
176 #if ARCH_X86 || ARCH_X86_64
177 disabled_ = GetParam().flag != 0 && !(x86_simd_caps() & GetParam().flag);
178 #else
179 disabled_ = GetParam().flag != 0;
180 #endif
182 void TearDown() {
183 aom_free(input_);
184 aom_free(temp_);
185 aom_free(output_);
187 int disabled_;
188 float *input_;
189 float *temp_;
190 float *output_;
193 TEST_P(IFFT2DTest, Correctness) {
194 if (disabled_) return;
195 int n = GetParam().n;
196 ASSERT_GE(n, 2);
197 std::vector<float> expected(n * n);
198 std::vector<float> actual(n * n);
199 // Do forward transform then invert to make sure we get back expected
200 for (int y = 0; y < n; ++y) {
201 for (int x = 0; x < n; ++x) {
202 expected[y * n + x] = 1;
203 std::vector<std::complex<float> > input_c = fft2d(&expected[0], n);
204 for (int i = 0; i < n * n; ++i) {
205 input_[2 * i + 0] = input_c[i].real();
206 input_[2 * i + 1] = input_c[i].imag();
208 GetParam().ifft(&input_[0], &temp_[0], &output_[0]);
210 for (int yy = 0; yy < n; ++yy) {
211 for (int xx = 0; xx < n; ++xx) {
212 EXPECT_NEAR(expected[yy * n + xx], output_[yy * n + xx] / (n * n),
213 1e-5);
216 expected[y * n + x] = 0;
221 TEST_P(IFFT2DTest, Benchmark) {
222 if (disabled_) return;
223 int n = GetParam().n;
224 float sum = 0;
225 for (int i = 0; i < 1000 * (64 - n); ++i) {
226 input_[i % (n * n)] = 1;
227 GetParam().ifft(&input_[0], &temp_[0], &output_[0]);
228 sum += output_[0];
229 input_[i % (n * n)] = 0;
232 INSTANTIATE_TEST_CASE_P(
233 IFFT2DTestC, IFFT2DTest,
234 ::testing::Values(IFFTTestArg(2, aom_ifft2x2_float_c, 0),
235 IFFTTestArg(4, aom_ifft4x4_float_c, 0),
236 IFFTTestArg(8, aom_ifft8x8_float_c, 0),
237 IFFTTestArg(16, aom_ifft16x16_float_c, 0),
238 IFFTTestArg(32, aom_ifft32x32_float_c, 0)));
239 #if ARCH_X86 || ARCH_X86_64
240 INSTANTIATE_TEST_CASE_P(
241 IFFT2DTestSSE2, IFFT2DTest,
242 ::testing::Values(IFFTTestArg(4, aom_ifft4x4_float_sse2, HAS_SSE2),
243 IFFTTestArg(8, aom_ifft8x8_float_sse2, HAS_SSE2),
244 IFFTTestArg(16, aom_ifft16x16_float_sse2, HAS_SSE2),
245 IFFTTestArg(32, aom_ifft32x32_float_sse2, HAS_SSE2)));
247 INSTANTIATE_TEST_CASE_P(
248 IFFT2DTestAVX2, IFFT2DTest,
249 ::testing::Values(IFFTTestArg(8, aom_ifft8x8_float_avx2, HAS_AVX2),
250 IFFTTestArg(16, aom_ifft16x16_float_avx2, HAS_AVX2),
251 IFFTTestArg(32, aom_ifft32x32_float_avx2, HAS_AVX2)));
252 #endif
253 } // namespace