Backed out 3 changesets (bug 1790375) for causing wd failures on fetch_error.py....
[gecko.git] / third_party / aom / test / convolve_round_test.cc
blob2f801e7d4639cd38f1d59b0ef94d1f6802093277
1 /*
2 * Copyright (c) 2017, Alliance for Open Media. All rights reserved
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
12 #include <assert.h>
14 #include "config/av1_rtcd.h"
16 #include "aom/aom_integer.h"
17 #include "aom_ports/aom_timer.h"
18 #include "test/acm_random.h"
19 #include "test/clear_system_state.h"
20 #include "test/register_state_check.h"
21 #include "test/util.h"
22 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
24 using libaom_test::ACMRandom;
26 namespace {
27 #define CONVOLVE_ROUNDING_PARAM \
28 const int32_t *src, int src_stride, uint8_t *dst, int dst_stride, int w, \
29 int h, int bits
31 typedef void (*ConvolveRoundFunc)(CONVOLVE_ROUNDING_PARAM);
33 typedef void (*ConvolveRoundFuncHbd)(CONVOLVE_ROUNDING_PARAM, int bd);
35 template <ConvolveRoundFuncHbd fn>
36 void highbd_convolve_rounding_8(CONVOLVE_ROUNDING_PARAM) {
37 const int bd = 8;
38 fn(src, src_stride, dst, dst_stride, w, h, bits, bd);
41 template <ConvolveRoundFuncHbd fn>
42 void highbd_convolve_rounding_10(CONVOLVE_ROUNDING_PARAM) {
43 const int bd = 10;
44 fn(src, src_stride, dst, dst_stride, w, h, bits, bd);
47 template <ConvolveRoundFuncHbd fn>
48 void highbd_convolve_rounding_12(CONVOLVE_ROUNDING_PARAM) {
49 const int bd = 12;
50 fn(src, src_stride, dst, dst_stride, w, h, bits, bd);
53 typedef enum { LOWBITDEPTH_TEST, HIGHBITDEPTH_TEST } DataPathType;
55 using ::testing::tuple;
57 typedef tuple<ConvolveRoundFunc, ConvolveRoundFunc, DataPathType>
58 ConvolveRoundParam;
60 const int kTestNum = 5000;
62 class ConvolveRoundTest : public ::testing::TestWithParam<ConvolveRoundParam> {
63 protected:
64 ConvolveRoundTest()
65 : func_ref_(GET_PARAM(0)), func_(GET_PARAM(1)), data_path_(GET_PARAM(2)) {
67 virtual ~ConvolveRoundTest() {}
69 virtual void SetUp() {
70 const size_t block_size = 128 * 128;
71 src_ = reinterpret_cast<int32_t *>(
72 aom_memalign(16, block_size * sizeof(*src_)));
73 dst_ref_ = reinterpret_cast<uint16_t *>(
74 aom_memalign(16, block_size * sizeof(*dst_ref_)));
75 dst_ = reinterpret_cast<uint16_t *>(
76 aom_memalign(16, block_size * sizeof(*dst_)));
79 virtual void TearDown() {
80 aom_free(src_);
81 aom_free(dst_ref_);
82 aom_free(dst_);
85 void ConvolveRoundingRun() {
86 int test_num = 0;
87 const int src_stride = 128;
88 const int dst_stride = 128;
89 int bits = 13;
90 uint8_t *dst = 0;
91 uint8_t *dst_ref = 0;
93 if (data_path_ == LOWBITDEPTH_TEST) {
94 dst = reinterpret_cast<uint8_t *>(dst_);
95 dst_ref = reinterpret_cast<uint8_t *>(dst_ref_);
96 } else if (data_path_ == HIGHBITDEPTH_TEST) {
97 dst = CONVERT_TO_BYTEPTR(dst_);
98 dst_ref = CONVERT_TO_BYTEPTR(dst_ref_);
99 } else {
100 assert(0);
103 while (test_num < kTestNum) {
104 int block_size = test_num % BLOCK_SIZES_ALL;
105 int w = block_size_wide[block_size];
106 int h = block_size_high[block_size];
108 if (test_num % 2 == 0)
109 bits -= 1;
110 else
111 bits += 1;
113 GenerateBufferWithRandom(src_, src_stride, bits, w, h);
115 func_ref_(src_, src_stride, dst_ref, dst_stride, w, h, bits);
116 ASM_REGISTER_STATE_CHECK(
117 func_(src_, src_stride, dst, dst_stride, w, h, bits));
119 if (data_path_ == LOWBITDEPTH_TEST) {
120 for (int r = 0; r < h; ++r) {
121 for (int c = 0; c < w; ++c) {
122 ASSERT_EQ(dst_ref[r * dst_stride + c], dst[r * dst_stride + c])
123 << "Mismatch at r: " << r << " c: " << c << " w: " << w
124 << " h: " << h << " test: " << test_num;
127 } else {
128 for (int r = 0; r < h; ++r) {
129 for (int c = 0; c < w; ++c) {
130 ASSERT_EQ(dst_ref_[r * dst_stride + c], dst_[r * dst_stride + c])
131 << "Mismatch at r: " << r << " c: " << c << " w: " << w
132 << " h: " << h << " test: " << test_num;
137 test_num++;
141 void GenerateBufferWithRandom(int32_t *src, int src_stride, int bits, int w,
142 int h) {
143 int32_t number;
144 for (int r = 0; r < h; ++r) {
145 for (int c = 0; c < w; ++c) {
146 number = static_cast<int32_t>(rand_.Rand31());
147 number %= 1 << (bits + 9);
148 src[r * src_stride + c] = number;
153 ACMRandom rand_;
154 int32_t *src_;
155 uint16_t *dst_ref_;
156 uint16_t *dst_;
158 ConvolveRoundFunc func_ref_;
159 ConvolveRoundFunc func_;
160 DataPathType data_path_;
163 TEST_P(ConvolveRoundTest, BitExactCheck) { ConvolveRoundingRun(); }
165 using ::testing::make_tuple;
166 #if HAVE_AVX2
167 const ConvolveRoundParam kConvRndParamArray[] = {
168 make_tuple(&av1_convolve_rounding_c, &av1_convolve_rounding_avx2,
169 LOWBITDEPTH_TEST),
170 make_tuple(&highbd_convolve_rounding_8<av1_highbd_convolve_rounding_c>,
171 &highbd_convolve_rounding_8<av1_highbd_convolve_rounding_avx2>,
172 HIGHBITDEPTH_TEST),
173 make_tuple(&highbd_convolve_rounding_10<av1_highbd_convolve_rounding_c>,
174 &highbd_convolve_rounding_10<av1_highbd_convolve_rounding_avx2>,
175 HIGHBITDEPTH_TEST),
176 make_tuple(&highbd_convolve_rounding_12<av1_highbd_convolve_rounding_c>,
177 &highbd_convolve_rounding_12<av1_highbd_convolve_rounding_avx2>,
178 HIGHBITDEPTH_TEST)
180 INSTANTIATE_TEST_CASE_P(AVX2, ConvolveRoundTest,
181 ::testing::ValuesIn(kConvRndParamArray));
182 #endif // HAVE_AVX2
183 } // namespace