Add dr prediction test
[aom.git] / test / av1_highbd_iht_test.cc
blob8cadc85e79648fe36e5f614b1a7055ddde33890c
1 /*
2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
12 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
14 #include "config/av1_rtcd.h"
16 #include "test/acm_random.h"
17 #include "test/clear_system_state.h"
18 #include "test/register_state_check.h"
19 #include "test/util.h"
20 #include "av1/common/enums.h"
21 #include "aom_dsp/aom_dsp_common.h"
22 #include "aom_ports/mem.h"
24 namespace {
26 using ::testing::tuple;
27 using libaom_test::ACMRandom;
29 typedef void (*HbdHtFunc)(const int16_t *input, int32_t *output, int stride,
30 TX_TYPE tx_type, int bd);
32 typedef void (*IHbdHtFunc)(const int32_t *coeff, uint16_t *output, int stride,
33 TX_TYPE tx_type, int bd);
35 // Test parameter argument list:
36 // <transform reference function,
37 // optimized inverse transform function,
38 // inverse transform reference function,
39 // num_coeffs,
40 // tx_type,
41 // bit_depth>
42 typedef tuple<HbdHtFunc, IHbdHtFunc, IHbdHtFunc, int, TX_TYPE, int> IHbdHtParam;
44 class AV1HighbdInvHTNxN : public ::testing::TestWithParam<IHbdHtParam> {
45 public:
46 virtual ~AV1HighbdInvHTNxN() {}
48 virtual void SetUp() {
49 txfm_ref_ = GET_PARAM(0);
50 inv_txfm_ = GET_PARAM(1);
51 inv_txfm_ref_ = GET_PARAM(2);
52 num_coeffs_ = GET_PARAM(3);
53 tx_type_ = GET_PARAM(4);
54 bit_depth_ = GET_PARAM(5);
56 input_ = reinterpret_cast<int16_t *>(
57 aom_memalign(16, sizeof(input_[0]) * num_coeffs_));
59 // Note:
60 // Inverse transform input buffer is 32-byte aligned
61 // Refer to <root>/av1/encoder/context_tree.c, function,
62 // void alloc_mode_context().
63 coeffs_ = reinterpret_cast<int32_t *>(
64 aom_memalign(32, sizeof(coeffs_[0]) * num_coeffs_));
65 output_ = reinterpret_cast<uint16_t *>(
66 aom_memalign(32, sizeof(output_[0]) * num_coeffs_));
67 output_ref_ = reinterpret_cast<uint16_t *>(
68 aom_memalign(32, sizeof(output_ref_[0]) * num_coeffs_));
71 virtual void TearDown() {
72 aom_free(input_);
73 aom_free(coeffs_);
74 aom_free(output_);
75 aom_free(output_ref_);
76 libaom_test::ClearSystemState();
79 protected:
80 void RunBitexactCheck();
82 private:
83 int GetStride() const {
84 if (16 == num_coeffs_) {
85 return 4;
86 } else if (64 == num_coeffs_) {
87 return 8;
88 } else if (256 == num_coeffs_) {
89 return 16;
90 } else if (1024 == num_coeffs_) {
91 return 32;
92 } else if (4096 == num_coeffs_) {
93 return 64;
94 } else {
95 return 0;
99 HbdHtFunc txfm_ref_;
100 IHbdHtFunc inv_txfm_;
101 IHbdHtFunc inv_txfm_ref_;
102 int num_coeffs_;
103 TX_TYPE tx_type_;
104 int bit_depth_;
106 int16_t *input_;
107 int32_t *coeffs_;
108 uint16_t *output_;
109 uint16_t *output_ref_;
112 void AV1HighbdInvHTNxN::RunBitexactCheck() {
113 ACMRandom rnd(ACMRandom::DeterministicSeed());
114 const int stride = GetStride();
115 const int num_tests = 20000;
116 const uint16_t mask = (1 << bit_depth_) - 1;
118 for (int i = 0; i < num_tests; ++i) {
119 for (int j = 0; j < num_coeffs_; ++j) {
120 input_[j] = (rnd.Rand16() & mask) - (rnd.Rand16() & mask);
121 output_ref_[j] = rnd.Rand16() & mask;
122 output_[j] = output_ref_[j];
125 txfm_ref_(input_, coeffs_, stride, tx_type_, bit_depth_);
126 inv_txfm_ref_(coeffs_, output_ref_, stride, tx_type_, bit_depth_);
127 ASM_REGISTER_STATE_CHECK(
128 inv_txfm_(coeffs_, output_, stride, tx_type_, bit_depth_));
130 for (int j = 0; j < num_coeffs_; ++j) {
131 EXPECT_EQ(output_ref_[j], output_[j])
132 << "Not bit-exact result at index: " << j << " At test block: " << i;
137 TEST_P(AV1HighbdInvHTNxN, InvTransResultCheck) { RunBitexactCheck(); }
139 using ::testing::make_tuple;
#if HAVE_SSE4_1
// Each PARAM_LIST_* expands to the first four tuple fields: forward C
// reference, SSE4.1 inverse under test, inverse C reference, and the number
// of coefficients in the block.
#define PARAM_LIST_4X4                                   \
  &av1_fwd_txfm2d_4x4_c, &av1_inv_txfm2d_add_4x4_sse4_1, \
      &av1_inv_txfm2d_add_4x4_c, 16
#define PARAM_LIST_8X8                                   \
  &av1_fwd_txfm2d_8x8_c, &av1_inv_txfm2d_add_8x8_sse4_1, \
      &av1_inv_txfm2d_add_8x8_c, 64
#define PARAM_LIST_16X16                                     \
  &av1_fwd_txfm2d_16x16_c, &av1_inv_txfm2d_add_16x16_sse4_1, \
      &av1_inv_txfm2d_add_16x16_c, 256
#define PARAM_LIST_64X64                                     \
  &av1_fwd_txfm2d_64x64_c, &av1_inv_txfm2d_add_64x64_sse4_1, \
      &av1_inv_txfm2d_add_64x64_c, 4096

// Emits the 10-bit tuple followed by the 12-bit tuple for one transform type.
// The size is passed as a PARAM_LIST_* macro; it is variadic because that
// macro expands to several comma-separated fields.
#define IHT_SSE4_1_BOTH_BD(type, ...) \
  make_tuple(__VA_ARGS__, type, 10), make_tuple(__VA_ARGS__, type, 12)

// Emits both bit depths for the nine DCT/ADST/FLIPADST combinations, in the
// order DCT_DCT, ADST_DCT, DCT_ADST, ADST_ADST, FLIPADST_DCT, DCT_FLIPADST,
// FLIPADST_FLIPADST, ADST_FLIPADST, FLIPADST_ADST.
#define IHT_SSE4_1_ALL_TYPES(...)                      \
  IHT_SSE4_1_BOTH_BD(DCT_DCT, __VA_ARGS__),            \
      IHT_SSE4_1_BOTH_BD(ADST_DCT, __VA_ARGS__),       \
      IHT_SSE4_1_BOTH_BD(DCT_ADST, __VA_ARGS__),       \
      IHT_SSE4_1_BOTH_BD(ADST_ADST, __VA_ARGS__),      \
      IHT_SSE4_1_BOTH_BD(FLIPADST_DCT, __VA_ARGS__),   \
      IHT_SSE4_1_BOTH_BD(DCT_FLIPADST, __VA_ARGS__),   \
      IHT_SSE4_1_BOTH_BD(FLIPADST_FLIPADST, __VA_ARGS__), \
      IHT_SSE4_1_BOTH_BD(ADST_FLIPADST, __VA_ARGS__),  \
      IHT_SSE4_1_BOTH_BD(FLIPADST_ADST, __VA_ARGS__)

const IHbdHtParam kArrayIhtParam[] = {
  // 16x16
  IHT_SSE4_1_ALL_TYPES(PARAM_LIST_16X16),
  // 8x8
  IHT_SSE4_1_ALL_TYPES(PARAM_LIST_8X8),
  // 4x4
  IHT_SSE4_1_ALL_TYPES(PARAM_LIST_4X4),
  // 64x64 (DCT_DCT only)
  IHT_SSE4_1_BOTH_BD(DCT_DCT, PARAM_LIST_64X64),
};

#undef IHT_SSE4_1_ALL_TYPES
#undef IHT_SSE4_1_BOTH_BD

INSTANTIATE_TEST_CASE_P(SSE4_1, AV1HighbdInvHTNxN,
                        ::testing::ValuesIn(kArrayIhtParam));
#endif  // HAVE_SSE4_1
#if HAVE_AVX2
// First four tuple fields for 32x32: forward C reference, AVX2 inverse under
// test, inverse C reference, and the number of coefficients in the block.
#define PARAM_LIST_32X32                                   \
  &av1_fwd_txfm2d_32x32_c, &av1_inv_txfm2d_add_32x32_avx2, \
      &av1_inv_txfm2d_add_32x32_c, 1024

// Emits the 10-bit tuple followed by the 12-bit tuple for one transform type.
#define IHT_AVX2_BOTH_BD(type, ...) \
  make_tuple(__VA_ARGS__, type, 10), make_tuple(__VA_ARGS__, type, 12)

// 32x32 is exercised for DCT_DCT only.
const IHbdHtParam kArrayIhtParam32x32[] = {
  IHT_AVX2_BOTH_BD(DCT_DCT, PARAM_LIST_32X32),
};

#undef IHT_AVX2_BOTH_BD

INSTANTIATE_TEST_CASE_P(AVX2, AV1HighbdInvHTNxN,
                        ::testing::ValuesIn(kArrayIhtParam32x32));

#endif  // HAVE_AVX2
236 } // namespace