2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
16 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
18 #include "config/av1_rtcd.h"
20 #include "aom_ports/aom_timer.h"
21 #include "av1/encoder/cnn.h"
22 #include "av1/encoder/partition_cnn_weights.h"
23 #include "test/acm_random.h"
24 #include "test/function_equivalence_test.h"
25 #include "test/util.h"
27 #define SQR(x) ((x) * (x))
29 // Best possible pixelwise guaranteed precision given each float has at most
30 // 3 specified decimals.
31 #define PIXELWISE_FLOAT_TOL 1E-2
33 #define MSE_FLOAT_TOL 1E-6
36 // CNN convolve pixelwise error threshold for functional equivalence.
37 #define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
41 class CNNTest
: public ::testing::Test
{
43 static void RunCNNTest(int image_width
, int image_height
, const float *input
,
44 const float *expected
, const CNN_CONFIG
*cnn_config
,
45 int in_stride
, CNN_THREAD_DATA
*thread_data
,
47 int out_width
, out_height
, out_channels
;
48 av1_find_cnn_output_size(image_width
, image_height
, cnn_config
, &out_width
,
49 &out_height
, &out_channels
);
51 const int out_size
= out_width
* out_height
;
52 const int out_stride
= out_width
;
55 (float *)aom_malloc(sizeof(*output_
) * out_size
* out_channels
);
56 ASSERT_NE(output_
, nullptr);
57 float *output
[CNN_MAX_CHANNELS
] = { nullptr };
58 for (int channel
= 0; channel
< out_channels
; ++channel
) {
59 output
[channel
] = output_
+ (channel
* out_size
);
61 const int num_outputs
= 1;
62 const int output_chs
[1] = { out_channels
};
63 const int output_strides
[1] = { out_stride
};
64 CNN_MULTI_OUT output_struct
= { num_outputs
, output_chs
, output_strides
,
67 RunMultiOutCNNTest(&input
, image_width
, image_height
, in_stride
, cnn_config
,
68 thread_data
, &output_struct
, &expected
, tolerance
);
73 static void RunMultiOutCNNTest(const float **input
, int image_width
,
74 int image_height
, int in_stride
,
75 const CNN_CONFIG
*cnn_config
,
76 CNN_THREAD_DATA
*thread_data
,
77 CNN_MULTI_OUT
*output
, const float **expected
,
79 const int num_outputs
= output
->num_outputs
;
80 const int *output_chs
= output
->output_channels
;
82 int *out_widths
= (int *)aom_calloc(num_outputs
, sizeof(*out_widths
));
83 int *out_heights
= (int *)aom_calloc(num_outputs
, sizeof(*out_heights
));
84 int *not_used
= (int *)aom_calloc(num_outputs
, sizeof(*not_used
));
85 ASSERT_NE(out_widths
, nullptr);
86 ASSERT_NE(out_heights
, nullptr);
87 ASSERT_NE(not_used
, nullptr);
89 av1_find_cnn_output_size(image_width
, image_height
, cnn_config
, out_widths
,
90 out_heights
, not_used
);
91 ASSERT_TRUE(av1_cnn_predict(input
, image_width
, image_height
, in_stride
,
92 cnn_config
, thread_data
, output
));
94 int channel_offset
= 0;
95 for (int output_idx
= 0; output_idx
< num_outputs
; output_idx
++) {
96 const float *expected_out
= expected
[output_idx
];
97 const int curr_output_chs
= output_chs
[output_idx
];
98 const int out_size
= out_widths
[output_idx
] * out_heights
[output_idx
];
101 int expected_ite
= 0;
102 for (int channel
= 0; channel
< curr_output_chs
; ++channel
) {
103 const float *buf_out
= output
->output_buffer
[channel_offset
];
105 for (int i
= 0; i
< out_size
; ++i
) {
106 EXPECT_NEAR(expected_out
[expected_ite
], buf_out
[i
],
108 << " output " << output_idx
<< " channel " << channel
<< " pixel "
109 << expected_ite
% out_size
<< ": " << expected_out
[expected_ite
]
110 << "/" << buf_out
[i
] << std::endl
;
111 mse
+= SQR(expected_out
[expected_ite
] - buf_out
[i
]);
117 mse
/= (out_size
* curr_output_chs
);
118 EXPECT_LE(mse
, tolerance
) << " output " << output_idx
<< std::endl
;
121 aom_free(out_widths
);
122 aom_free(out_heights
);
126 static void AssignLayerWeightsBiases(CNN_CONFIG
*cnn_config
, float *weights
,
128 size_t weight_offset
= 0;
129 size_t bias_offset
= 0;
130 for (int layer
= 0; layer
< cnn_config
->num_layers
; ++layer
) {
131 CNN_LAYER_CONFIG
*layer_config
= &cnn_config
->layer_config
[layer
];
132 layer_config
->weights
= weights
+ weight_offset
;
133 layer_config
->bias
= bias
+ bias_offset
;
134 weight_offset
+= layer_config
->filter_width
*
135 layer_config
->filter_height
* layer_config
->in_channels
*
136 layer_config
->out_channels
;
137 bias_offset
+= layer_config
->out_channels
;
139 ASSERT_NE(layer_config
->weights
, nullptr);
140 ASSERT_NE(layer_config
->bias
, nullptr);
147 TEST_F(CNNTest
, TestMultilayerConvolution
) {
148 int image_height
= 16;
149 int image_width
= 16;
150 int filter_height
= 5;
151 int filter_width
= 4;
154 -3, 1, -3, 2, -2, -2, 2, -2, 1, -2, -3, 1, 2, 2, 2, -2, 0, 1, -1,
155 -3, -1, -1, 1, 0, -3, 1, 0, -1, 1, 0, 0, -3, -3, -3, 0, 2, 1, -1,
156 2, 0, 1, -3, -1, 2, 2, 1, -2, 0, -1, 0, -2, -2, -1, 1, 0, 0, 0,
157 -2, -2, -2, 1, 1, -2, 1, 1, -2, -2, 1, -2, -1, -2, -3, 2, -3, -1, 1,
158 0, -2, -2, -2, 1, -2, -2, -1, -1, 2, 2, 2, -1, 1, -3, -3, 0, 2, 0,
159 2, 1, -3, -3, 1, 2, 2, 1, -2, -3, 0, -3, 0, -3, -2, 0, 1, 1, 0,
160 -3, 2, -1, 2, 1, 0, 1, -2, 1, -1, -1, 2, 0, -2, -3, 1, 1, -2, -1,
161 -3, -3, -1, 0, -3, -2, 0, 0, 1, 0, -3, -2, -1, 1, 0, 2, 1, 0, -3,
162 -2, -3, -3, -1, 0, -2, 2, -1, -3, 0, -1, -1, 2, 0, -3, -2, -1, 0, 0,
163 1, -2, 1, 2, 1, 2, 2, -3, 2, -1, 0, 0, -1, 0, 2, 2, -1, 2, -2,
164 1, 1, -3, -3, 1, -1, -1, -2, 2, -2, -2, 2, -1, -3, 2, -3, 1, -1, -1,
165 -3, 1, -1, 1, 0, -3, -3, 1, -3, -3, 0, 2, 2, -2, -1, 2, 0, 2, 1,
166 -1, -3, 0, 0, -1, -1, 1, 0, 2, 0, -3, 2, 1, 0, 1, -3, 2, -3, -3,
167 -1, -3, -3, 2, 0, 2, -2, 1, -1,
171 -2, 2, -2, 2, -1, -3, 2, 2, 0, 0, -3, -1, -2, -3, 1, -1, 0, 0, 0,
172 2, -2, 2, -2, -3, 1, 1, 1, -3, -1, 0, 1, 2, -2, 0, -1, -3, -1, -2,
173 2, -3, -3, 1, -2, -3, 0, 2, 1, -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
174 -1, -3, -1, -2, -2, -3, 2, 0, -3, 0, -3, -3, 1, -3, -1, 0, -1, 1, 1,
175 -1, 1, -2, 0, 2, 0, -3, 1, -1, -1, 2, 0, 1, -3, -3, 1, 2, -3, -3,
176 1, -3, 2, 0, -3, 1, 2, 2, -2, -1, -2, 1, 1, 0, -2, -2, 1, 2, -1,
177 -3, 1, -2, 2, -3, -2, -3, 2, 1, 0, -2, 0, 1, -3, 2, -2, -2, 0, 2,
178 -3, 2, 0, 0, 1, -2, 1, 1, -2, -1, -2, 1, -2, 0, -2, -2, 0, -1, -1,
179 -3, -3, -3, 1, -3, -2, 2, -1, 2, 0, 2, -2, 2, -2, 1, -3, -3, -1, 0,
180 2, 2, 1, -1, -3, -1, -3, 2, 1, -2, 0, -3, -1, -3, -1, 2, 1, 0, 2,
181 -1, 1, 0, 1, 2, -1, -2, 2, 1, -3, -1, -3, 0, 1, -2, 0, -2, -3, 0,
182 -2, 2, 2, 0, 0, 2, -3, 2, -3, -2, 1, 2, -3, -3, -1, -3, 0, -3, -3,
183 -2, -2, -2, 0, 0, 1, 0, 0, -1, 0, 0, -3, 0, -3, -1, -2, 1, -2, -1,
184 2, -2, 0, 0, 1, 0, -2, -1, 0, -3, 1, 0, -1, -3, 1, -1, 1, -1, -3,
185 1, 0, 1, 1, -1, 2, 2, 0, 0, 1, -3, 2, -2, -2, -3, -2, -1, -2, 2,
186 0, 2, -2, -3, -1, -3, 2, 2, -1, 2, 2, -1, 0, -3, 1,
190 1, -1, 0, 1, 1, 1, -2,
193 float expected_same
[] = {
194 -1125, 2926, 6406, 631, -1244, 97, -1454, 2526, 1065, 3292, 3464,
195 2553, -330, 532, 1038, 1182, -402, 3758, 3392, 9854, 4365, 1408,
196 4736, 3134, 3838, 2409, 3221, 4350, 6750, 4045, 815, 1188, 2959,
197 9802, 9590, 4572, 5740, 4253, 1701, 7974, 7012, 6854, 7093, 3907,
198 4539, 3886, 4267, 3505, 465, 7824, 9219, 10026, 7968, 957, 2295,
199 5594, 10811, 9641, 5950, 10043, 8783, 3132, 1421, 1110, 4108, 13929,
200 10660, -84, -61, 3932, -180, 6811, 13393, 15147, 15640, 9337, 6961,
201 3808, 1604, 1398, 1047, 6739, 10144, 6517, 4698, 2678, 7389, 2595,
202 5248, 12075, 11272, 13951, 8820, 1090, 2199, 2206, 2788, 12116, 6683,
203 2612, -291, 3183, 9414, 12316, 14524, 12333, 13208, 7832, 4664, 4657,
204 3534, 1298, -666, 4250, 7707, 9103, 5760, 688, 9571, 15782, 14203,
205 14878, 17339, 14684, 8690, 5671, 875, 1429, 1531, 6173, 2984, 5558,
206 2996, 7928, 6733, 16117, 15262, 12757, 7980, 3923, 4795, 5973, 2051,
207 455, -1922, 1816, 5906, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
208 7451, 6666, 74, -1645, -35, -391, 3813, 7324, 892, 1656, 6095,
209 12193, 14648, 12156, 14663, 10251, 10325, 7821, 3925, 323, 697, 442,
210 1324, 4669, 7002, 5485, 5171, 5086, 10582, 11053, 9709, 11353, 8543,
211 5256, 2873, 235, -628, 1496, 1878, -867, 3420, 6865, 5937, 10182,
212 13277, 10069, 10789, 5998, 624, -2082, 4417, 1258, -1080, -819, -1430,
213 1033, 5220, 6335, 8471, 8980, 11908, 14430, 12584, 8404, 1576, -803,
214 985, 1481, 1367, -193, 873, 3684, 2288, 6676, 9477, 11155, 9602,
215 9707, 10507, 4739, 3174, -575, -178, 3002, 1710, 423, -477, 554,
216 3088, 2029, 5113, 5000, 3771, 6090, 5365, 1185, 2855, 399, -312,
220 float expected_replicate
[] = {
221 13768, 13528, 12999, 6906, 4618, 4043, 2611, 9955, 6685, 4776, 2753,
222 1036, 3063, 4544, 5183, 7349, 12451, 12501, 9131, 12753, 8908, 4058,
223 6299, 7542, 7115, 3307, 3360, 3543, 9754, 7808, 5991, 9019, 14320,
224 14919, 12492, 6871, 7373, 3336, 2085, 10604, 9377, 6882, 5009, 3103,
225 6220, 6278, 7588, 10196, 11045, 11563, 11842, 11911, 8279, 2030, 1858,
226 6368, 12123, 9909, 6347, 10345, 9365, 4038, 1673, 3051, 16492, 16649,
227 12276, 408, -301, 4122, -654, 7864, 14038, 15279, 15315, 9744, 8243,
228 5298, 746, 380, 9824, 9124, 10895, 6640, 4712, 2669, 6980, 2759,
229 5385, 12345, 11336, 13129, 8600, 2370, 3682, 5219, 12407, 13123, 6784,
230 2612, -291, 3183, 9414, 12316, 14524, 12333, 13397, 7543, 3916, 4153,
231 4477, 4314, 7983, 8418, 9163, 9103, 5760, 688, 9571, 15782, 14203,
232 14878, 17718, 14570, 7940, 6642, 5094, 7133, 9964, 10219, 3224, 5558,
233 2996, 7928, 6733, 16117, 15262, 12757, 7958, 4401, 5187, 5476, 5529,
234 6055, 2206, 3909, 6015, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
235 6967, 6840, 481, -1600, 274, 1, 10373, 8514, 1123, 2117, 6758,
236 12736, 16223, 13585, 15988, 11771, 10600, 7918, 4156, 2840, 3111, 3287,
237 6359, 7652, 8813, 6530, 6967, 7789, 13671, 13990, 13247, 13241, 9836,
238 5251, 3024, 2313, 1834, 4187, 2637, -1312, 2139, 7378, 7665, 11933,
239 15591, 15314, 15678, 9531, 2820, -1516, 3400, 1314, 22, 363, -2896,
240 -898, 5906, 7308, 10650, 12975, 16978, 20370, 18817, 12381, 4118, -861,
241 -137, 236, 1802, 1632, -350, 2334, 3400, 8680, 14064, 18216, 18675,
242 21765, 22871, 11491, 4937, -1555, -11, 1669, 2392, 3265, -5254, -217,
243 5001, 8063, 13444, 18884, 19706, 22794, 21064, 9545, 6689, -7, 289,
247 float expected_valid
[] = {
248 2612, -291, 3183, 9414, 12316, 14524, 12333, 9103, 5760, 688,
249 9571, 15782, 14203, 14878, 5558, 2996, 7928, 6733, 16117, 15262,
250 12757, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
253 CNN_CONFIG cnn_config
= { 3,
321 // Weights and biases need to be specified separately because
323 AssignLayerWeightsBiases(&cnn_config
, weights
, bias
);
325 CNN_THREAD_DATA thread_data
= { 1, nullptr };
327 RunCNNTest(image_width
, image_height
, input
, expected_same
, &cnn_config
,
328 image_width
, &thread_data
, MSE_INT_TOL
);
330 for (int i
= 0; i
< cnn_config
.num_layers
; ++i
) {
331 cnn_config
.layer_config
[i
].pad
= PADDING_SAME_REPLICATE
;
334 RunCNNTest(image_width
, image_height
, input
, expected_replicate
, &cnn_config
,
335 image_width
, &thread_data
, MSE_INT_TOL
);
337 for (int i
= 0; i
< cnn_config
.num_layers
; ++i
) {
338 cnn_config
.layer_config
[i
].pad
= PADDING_VALID
;
341 RunCNNTest(image_width
, image_height
, input
, expected_valid
, &cnn_config
,
342 image_width
, &thread_data
, MSE_INT_TOL
);
345 TEST_F(CNNTest
, TestRELUSingleLayer
) {
347 int image_height
= 8;
348 int filter_height
= 5;
349 int filter_width
= 4;
351 0, -2, -3, 1, -1, 2, -2, 1, -3, -1, 0, 1, -2, -3, -2, -2,
352 1, -3, 2, -3, -1, -1, 2, 0, -2, -3, 0, -2, -3, 1, -1, -1,
353 2, -2, 0, -2, -3, -3, 1, 1, -1, 1, 0, 1, -3, 0, 2, 2,
354 0, -3, 1, -3, 2, -2, 1, -1, -1, -2, -3, -2, -1, -3, -2, -1,
356 float expected_same
[] = {
357 9, 0, 1, 1, 0, 3, 0, 19, 0, 12, 10, 0, 0, 0, 5, 0,
358 0, 18, 21, 7, 19, 4, 3, 0, 0, 9, 16, 0, 11, 16, 0, 11,
359 12, 2, 0, 11, 0, 16, 6, 0, 8, 22, 13, 10, 12, 0, 0, 0,
360 0, 1, 2, 12, 29, 6, 10, 0, 13, 0, 0, 5, 8, 10, 0, 0,
362 float expected_replicate
[] = {
363 18, 17, 12, 2, 0, 0, 5, 11, 0, 17, 22, 6, 0, 0, 17, 0,
364 0, 18, 21, 7, 19, 4, 3, 5, 3, 9, 16, 0, 11, 16, 0, 3,
365 3, 2, 0, 11, 0, 16, 6, 0, 17, 22, 13, 10, 12, 0, 0, 0,
366 0, 4, 1, 10, 30, 7, 10, 0, 23, 8, 0, 13, 15, 19, 8, 10,
368 float expected_valid
[] = {
369 18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
372 -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
374 float bias
[] = { -3 };
376 CNN_CONFIG cnn_config
= { 1,
402 CNN_THREAD_DATA thread_data
= { 1, nullptr };
404 RunCNNTest(image_width
, image_height
, input
, expected_same
, &cnn_config
,
405 image_width
, &thread_data
, MSE_INT_TOL
);
407 cnn_config
.layer_config
[0].pad
= PADDING_SAME_REPLICATE
;
409 RunCNNTest(image_width
, image_height
, input
, expected_replicate
, &cnn_config
,
410 image_width
, &thread_data
, MSE_INT_TOL
);
412 cnn_config
.layer_config
[0].pad
= PADDING_VALID
;
414 RunCNNTest(image_width
, image_height
, input
, expected_valid
, &cnn_config
,
415 image_width
, &thread_data
, MSE_INT_TOL
);
418 TEST_F(CNNTest
, TestVaryingStridesVaryingDimImages
) {
420 1, -5, -3, -4, -1, 1, 2, -3, 2, 2, -1, 1, -5, 1, 1,
421 -3, -5, 3, 1, 4, -2, -5, -2, -3, -5, 0, -1, -5, 2, -2,
422 -2, 1, -2, -4, 1, 3, -2, 2, 0, -3, 2, -3, -2, -3,
424 float bias
[] = { 2 };
426 CNN_CONFIG cnn_config
= { 1,
454 int image_height
= 24;
455 int image_width
= 17;
457 -1, -3, 4, 4, -5, 4, 3, -5, -1, -3, 4, -4, 2, -3, 3, -5, 2, -1, -5,
458 1, -1, 3, 1, -3, -3, 4, 0, 2, -3, -5, -5, -4, 0, -5, -2, -3, -1, -2,
459 2, -5, 4, 4, 0, -4, -3, 1, -3, -5, -4, -4, 1, -2, -3, 3, -3, -3, -1,
460 -5, -5, -2, 3, 1, -1, -5, -5, 1, -4, -2, -1, -2, -4, -4, 2, -2, 2, 1,
461 -2, -4, -1, 1, -2, -5, 3, -2, -1, -1, -5, -3, 1, -2, -2, -3, -1, -2, -4,
462 -2, 1, -4, -1, 4, 3, -4, 0, 4, 2, 2, 4, -3, -5, 2, 2, 1, -1, -4,
463 -2, 1, 3, 2, 0, 4, -1, -3, 2, 1, -4, 2, 2, -4, -2, 0, -2, -1, 4,
464 4, 2, 3, -4, 2, -4, -5, 4, -1, -3, -1, 0, -4, 1, 3, -1, -3, -5, 3,
465 -2, -4, 1, 2, -2, -3, -3, -5, 1, -3, -1, 0, -1, 3, -4, -1, -5, -5, 1,
466 0, 0, -2, -2, 2, -2, 0, 0, 2, 0, -3, 0, -1, -4, -4, -1, 3, -4, -4,
467 -1, 0, -5, -3, -2, 4, -3, -4, -4, 0, -5, 1, -2, -3, -3, -4, 4, 3, 4,
468 3, 3, -1, 3, 1, -3, -2, 3, 3, 0, 2, -4, -3, 2, 2, 0, -2, 4, -2,
469 2, -2, -1, -4, -2, 2, -4, 3, -1, 4, 1, 1, 4, -1, -4, -4, 1, 1, -2,
470 4, -1, 3, 2, -3, 4, 3, 1, 4, 0, -4, 2, 0, 2, 4, -2, -2, 4, 2,
471 -1, -2, 1, -3, 2, 3, -5, -3, 4, 4, 2, -5, -4, -5, -2, -4, 2, 0, 2,
472 -5, 4, -4, -2, -5, 2, 1, 0, 4, 1, -2, -3, -4, -3, -4, 3, 3, 2, 0,
473 -3, 1, -5, 4, 0, 4, -1, 3, -5, -5, -2, -1, -1, 4, 3, 3, 4, 3, -4,
474 4, -3, -3, -1, -4, -1, -4, -1, -2, 4, -2, -4, 4, 4, -3, -4, -1, 1, 2,
475 -1, -2, -2, 3, 2, 2, -3, 0, -1, 0, 3, 2, -5, 0, -4, 0, 0, 2, -4,
476 -1, -1, 0, -2, 0, 1, 0, 0, 4, -5, -1, -5, 2, -1, 0, 2, -1, 1, 3,
477 -3, -5, -2, -3, 4, -2, -2, -1, -3, -4, -1, -2, -4, 1, 4, -3, -2, -1, 3,
478 -3, -2, 3, 2, 1, -4, -3, -5, 1,
480 float expected_1
[] = {
481 41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
484 CNN_THREAD_DATA thread_data
= { 1, nullptr };
486 RunCNNTest(image_width
, image_height
, input
, expected_1
, &cnn_config
,
487 image_width
, &thread_data
, MSE_INT_TOL
);
489 cnn_config
.layer_config
[0].skip_width
= 6;
490 cnn_config
.layer_config
[0].skip_height
= 7;
492 float expected_2
[] = {
493 21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
495 RunCNNTest(image_width
, image_height
, input
, expected_2
, &cnn_config
,
496 image_width
, &thread_data
, MSE_INT_TOL
);
498 cnn_config
.layer_config
[0].skip_width
= 3;
499 cnn_config
.layer_config
[0].skip_height
= 10;
501 float expected_3
[] = {
502 -26, -21, -35, 69, 49, 4, -51, -43, -56,
503 -41, 15, -44, 40, -62, 63, 38, 27, 47,
505 RunCNNTest(image_width
, image_height
, input
, expected_3
, &cnn_config
,
506 image_width
, &thread_data
, MSE_INT_TOL
);
508 cnn_config
.layer_config
[0].skip_width
= 10;
509 cnn_config
.layer_config
[0].skip_height
= 3;
511 float expected_4
[] = {
512 21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
515 RunCNNTest(image_width
, image_height
, input
, expected_4
, &cnn_config
,
516 image_width
, &thread_data
, MSE_INT_TOL
);
519 TEST_F(CNNTest
, TestMaxPool
) {
521 int image_height
= 8;
524 1, -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8, 5, -1, -1, 9,
525 -3, 0, -2, 0, 6, 3, -4, 8, 7, 8, 7, -1, 4, -1, 0, 2,
526 -5, -2, 8, 5, 5, 4, 2, 7, 4, 6, 2, 8, 8, -4, -3, -4,
527 -3, -1, 2, 3, 3, 6, -5, 8, 9, 5, 0, -2, -1, 6, 5, 7,
531 49, 58, 70, 68, 68, 70, 48, 57, 88,
535 3, 1, 3, 4, -1, 5, -2, 1, -4,
542 CNN_CONFIG cnn_config
= { 1,
568 CNN_THREAD_DATA thread_data
= { 1, nullptr };
570 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
571 image_width
, &thread_data
, MSE_INT_TOL
);
574 TEST_F(CNNTest
, TestDeconvolveNonActivationSingleLayerSingleKernel
) {
576 int image_height
= 7;
578 9, 6, 181, 9, 218, 30, 80, 108, 68, 216, 70, 128, 179, 228,
579 33, 212, 34, 14, 48, 27, 230, 23, 202, 113, 80, 56, 122, 112,
582 float expected_1_same
[] = {
583 15, -30, 36, -525, 377, -193, 558, 531, 6, -24, -15, 124,
584 166, -561, -356, -754, -3, -3, -3, -3, -3, -3, -3, -3,
585 433, -311, 711, 381, 247, -317, 453, 129, 215, -627, -409, -885,
586 17, -255, -55, -647, -3, -3, -3, -3, -3, -3, -3, -3,
587 133, -719, 633, -225, 785, 191, 463, 79, 65, 9, 77, -853,
588 -365, -949, -15, -667, -3, -3, -3, -3, -3, -3, -3, -3,
589 355, -866, 990, 207, 747, 12, 520, -116, 176, -312, -133, -1370,
590 -426, -802, 143, -771, -3, -3, -3, -3, -3, -3, -3, -3,
591 65, -79, 127, -59, 135, -90, 195, 114, 31, -91, -57, -133,
592 17, -176, -72, -276, -3, -3, -3, -3, -3, -3, -3, -3,
593 457, -302, 733, 58, 470, -475, 829, 490, 227, -670, -440, -790,
594 153, -588, -294, -1150, -3, -3, -3, -3, -3, -3, -3, -3,
595 157, -251, 349, -185, 409, -293, 587, 251, 77, -187, -107, -369,
596 7, -481, -135, -827, -3, -3, -3, -3, -3, -3, -3, -3,
598 float expected_1_valid
[] = {
599 -30, 15, -30, 36, -525, 377, -193, 558, 531, 24, 24, 6,
600 6, -24, -15, 124, 166, -561, -356, -754, -21, -39, -3, -3,
601 -3, -3, -3, -3, -3, -3, -3, -3, -3, -657, 433, -311,
602 711, 381, 247, -317, 453, 129, 321, 321, 215, 215, -627, -409,
603 -885, 17, -255, -55, -647, -219, -435, -3, -3, -3, -3, -3,
604 -3, -3, -3, -3, -3, -3, -207, 133, -719, 633, -225, 785,
605 191, 463, 79, 381, 381, 65, 65, 9, 77, -853, -365, -949,
606 -15, -667, -259, -515, -3, -3, -3, -3, -3, -3, -3, -3,
607 -3, -3, -3, -540, 355, -866, 990, 207, 747, 12, 520, -116,
608 633, 633, 176, 176, -312, -133, -1370, -426, -802, 143, -771, -427,
609 -851, -3, -3, -3, -3, -3, -3, -3, -3, -3, -3, -3,
610 -105, 65, -79, 127, -59, 135, -90, 195, 114, 78, 78, 31,
611 31, -91, -57, -133, 17, -176, -72, -276, -57, -111, -3, -3,
612 -3, -3, -3, -3, -3, -3, -3, -3, -3, -693, 457, -302,
613 733, 58, 470, -475, 829, 490, 336, 336, 227, 227, -670, -440,
614 -790, 153, -588, -294, -1150, -229, -455, -3, -3, -3, -3, -3,
615 -3, -3, -3, -3, -3, -3, -243, 157, -251, 349, -185, 409,
616 -293, 587, 251, 333, 333, 77, 77, -187, -107, -369, 7, -481,
617 -135, -827, -227, -451,
619 float weights_1
[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
620 float bias_1
[] = { -3 };
622 CNN_CONFIG cnn_config
= { 1,
648 CNN_THREAD_DATA thread_data
= { 1, nullptr };
650 RunCNNTest(image_width
, image_height
, input
, expected_1_same
, &cnn_config
,
651 image_width
, &thread_data
, MSE_INT_TOL
);
653 // Change padding to valid
654 cnn_config
.layer_config
[0].pad
= PADDING_VALID
;
656 RunCNNTest(image_width
, image_height
, input
, expected_1_valid
, &cnn_config
,
657 image_width
, &thread_data
, MSE_INT_TOL
);
659 float expected_12_same
[] = {
660 15, -12, 6, 36, -9, -528, 377, -184, 513, 558, -12, 24,
661 6, -30, -15, -33, -21, 166, 154, -546, -356, -718, -30, -21,
662 433, -221, 561, 711, -33, -153, 247, -83, -87, 453, -111, 321,
663 215, -657, -409, -845, -93, 17, -43, -243, -55, -215, -327, -219,
664 133, -71, -447, 633, -219, 435, 785, -73, -177, 463, -131, 381,
665 65, -207, 77, -59, -651, -365, -797, -213, -15, -155, -387, -259,
666 355, -182, -150, 990, -231, 582, 747, -36, -540, 520, -215, 633,
667 176, -540, -133, -491, -687, -426, -882, -102, 143, 77, -639, -427,
668 65, -37, 57, 127, -17, -105, 135, -51, 60, 195, -30, 78,
669 31, -105, -57, -125, -45, 17, -11, -147, -72, -168, -84, -57,
670 457, -233, 618, 733, -26, -540, 470, -205, 264, 829, -116, 336,
671 227, -693, -440, -900, -72, 153, 107, -609, -294, -698, -342, -229,
672 157, -83, 69, 349, -59, -201, 409, -125, 27, 587, -115, 333,
673 77, -243, -107, -267, -171, 7, -105, -369, -135, -379, -339, -227,
675 float expected_12_valid
[] = {
676 -30, 15, -12, 6, 36, -9, -528, 377, -184, 513, 558, -12,
677 24, 24, 6, 6, -30, -15, -33, -21, 166, 154, -546, -356,
678 -718, -30, -21, -39, -657, 433, -221, 561, 711, -33, -153, 247,
679 -83, -87, 453, -111, 321, 321, 215, 215, -657, -409, -845, -93,
680 17, -43, -243, -55, -215, -327, -219, -435, -207, 133, -71, -447,
681 633, -219, 435, 785, -73, -177, 463, -131, 381, 381, 65, 65,
682 -207, 77, -59, -651, -365, -797, -213, -15, -155, -387, -259, -515,
683 -540, 355, -182, -150, 990, -231, 582, 747, -36, -540, 520, -215,
684 633, 633, 176, 176, -540, -133, -491, -687, -426, -882, -102, 143,
685 77, -639, -427, -851, -105, 65, -37, 57, 127, -17, -105, 135,
686 -51, 60, 195, -30, 78, 78, 31, 31, -105, -57, -125, -45,
687 17, -11, -147, -72, -168, -84, -57, -111, -693, 457, -233, 618,
688 733, -26, -540, 470, -205, 264, 829, -116, 336, 336, 227, 227,
689 -693, -440, -900, -72, 153, 107, -609, -294, -698, -342, -229, -455,
690 -243, 157, -83, 69, 349, -59, -201, 409, -125, 27, 587, -115,
691 333, 333, 77, 77, -243, -107, -267, -171, 7, -105, -369, -135,
692 -379, -339, -227, -451,
695 // Change skip_width, skip_height to {2, 3}
696 cnn_config
.layer_config
[0].skip_width
= 3;
697 cnn_config
.layer_config
[0].skip_height
= 2;
698 // Set padding to same
699 cnn_config
.layer_config
[0].pad
= PADDING_SAME_ZERO
;
701 RunCNNTest(image_width
, image_height
, input
, expected_12_same
, &cnn_config
,
702 image_width
, &thread_data
, MSE_INT_TOL
);
704 // Change padding to valid
705 cnn_config
.layer_config
[0].pad
= PADDING_VALID
;
706 RunCNNTest(image_width
, image_height
, input
, expected_12_valid
, &cnn_config
,
707 image_width
, &thread_data
, MSE_INT_TOL
);
709 cnn_config
.layer_config
[0].filter_width
= 4;
710 cnn_config
.layer_config
[0].filter_height
= 3;
711 float weights_2
[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
712 float bias_2
[] = { -4 };
713 cnn_config
.layer_config
[0].weights
= weights_2
;
714 cnn_config
.layer_config
[0].bias
= bias_2
;
716 cnn_config
.layer_config
[0].skip_width
= 5;
717 cnn_config
.layer_config
[0].skip_height
= 2;
718 float expected_2_same
[] = {
719 -13, -31, -13, -31, -4, -10, -22, -10, -22, -4, -185, -547,
720 -185, -547, -4, -13, -31, -13, -31, -4, -4, 14, -22, 32,
721 -4, -4, 8, -16, 20, -4, -4, 358, -366, 720, -4, -4,
722 14, -22, 32, -4, -195, -658, -213, -622, -4, -16, -94, -28,
723 -70, -4, 459, -244, 97, 480, -4, -85, -328, -103, -292, -4,
724 -4, 432, -440, 868, -4, -4, 56, -64, 116, -4, -4, 156,
725 -164, 316, -4, -4, 212, -220, 428, -4, 582, -208, 146, 664,
726 -4, -130, -652, -190, -532, -4, 166, -214, 6, 106, -4, 192,
727 -388, -24, 44, -4, -4, 132, -140, 268, -4, -4, 428, -436,
728 860, -4, -4, 136, -144, 276, -4, -4, 252, -260, 508, -4,
729 21, -541, -115, -269, -4, 416, -688, -16, 176, -4, 173, -103,
730 33, 177, -4, 168, -640, -88, -128, -4, -4, 354, -362, 712,
731 -4, -4, 452, -460, 908, -4, -4, 62, -70, 128, -4, -4,
732 420, -428, 844, -4, 499, -106, 141, 610, -4, 666, -46, 210,
733 866, -4, 47, -148, -19, -16, -4, 605, -85, 181, 763, -4,
734 -4, 64, -72, 132, -4, -4, 24, -32, 52, -4, -4, 92,
735 -100, 188, -4, -4, 50, -58, 104, -4, -132, -694, -200, -558,
736 -4, 15, -73, -13, -17, -4, -62, -610, -158, -418, -4, -36,
737 -343, -90, -235, -4, -4, 456, -464, 916, -4, -4, 42, -50,
738 88, -4, -4, 400, -408, 804, -4, -4, 222, -230, 448, -4,
739 606, -244, 146, 676, -4, 9, -172, -37, -80, -4, 480, -370,
740 76, 438, -4, 223, -340, -3, 112, -4, -4, 156, -164, 316,
741 -4, -4, 108, -116, 220, -4, -4, 240, -248, 484, -4, -4,
744 float expected_2_valid
[] = {
745 -13, -31, -13, -31, -4, -10, -22, -10, -22, -4, -185, -547,
746 -185, -547, -4, -13, -31, -13, -31, -4, 14, -22, 32, -4,
747 -4, 8, -16, 20, -4, -4, 358, -366, 720, -4, -4, 14,
748 -22, 32, -195, -658, -213, -622, -4, -16, -94, -28, -70, -4,
749 459, -244, 97, 480, -4, -85, -328, -103, -292, -4, 432, -440,
750 868, -4, -4, 56, -64, 116, -4, -4, 156, -164, 316, -4,
751 -4, 212, -220, 428, 582, -208, 146, 664, -4, -130, -652, -190,
752 -532, -4, 166, -214, 6, 106, -4, 192, -388, -24, 44, -4,
753 132, -140, 268, -4, -4, 428, -436, 860, -4, -4, 136, -144,
754 276, -4, -4, 252, -260, 508, 21, -541, -115, -269, -4, 416,
755 -688, -16, 176, -4, 173, -103, 33, 177, -4, 168, -640, -88,
756 -128, -4, 354, -362, 712, -4, -4, 452, -460, 908, -4, -4,
757 62, -70, 128, -4, -4, 420, -428, 844, 499, -106, 141, 610,
758 -4, 666, -46, 210, 866, -4, 47, -148, -19, -16, -4, 605,
759 -85, 181, 763, -4, 64, -72, 132, -4, -4, 24, -32, 52,
760 -4, -4, 92, -100, 188, -4, -4, 50, -58, 104, -132, -694,
761 -200, -558, -4, 15, -73, -13, -17, -4, -62, -610, -158, -418,
762 -4, -36, -343, -90, -235, -4, 456, -464, 916, -4, -4, 42,
763 -50, 88, -4, -4, 400, -408, 804, -4, -4, 222, -230, 448,
764 606, -244, 146, 676, -4, 9, -172, -37, -80, -4, 480, -370,
765 76, 438, -4, 223, -340, -3, 112, -4, 156, -164, 316, -4,
766 -4, 108, -116, 220, -4, -4, 240, -248, 484, -4, -4, 220,
767 -228, 444, 236, -4, 76, 316, -4, 164, -4, 52, 220, -4,
768 362, -4, 118, 484, -4, 332, -4, 108, 444,
770 // Set padding to same
771 cnn_config
.layer_config
[0].pad
= PADDING_SAME_ZERO
;
773 RunCNNTest(image_width
, image_height
, input
, expected_2_same
, &cnn_config
,
774 image_width
, &thread_data
, MSE_INT_TOL
);
776 cnn_config
.layer_config
[0].pad
= PADDING_VALID
;
778 RunCNNTest(image_width
, image_height
, input
, expected_2_valid
, &cnn_config
,
779 image_width
, &thread_data
, MSE_INT_TOL
);
781 cnn_config
.layer_config
[0].skip_width
= 2;
782 cnn_config
.layer_config
[0].skip_height
= 5;
783 float expected_21_same
[] = {
784 -31, -19, -49, -191, -565, -194, -574, -13, 14, -22, 44, -16,
785 382, -366, 738, -22, -4, 23, 32, 545, 20, 204, 720, 5,
786 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
787 -4, -4, -4, -4, -658, -252, -748, -114, -334, -192, -568, -112,
788 432, -440, 928, -64, 276, -164, 532, -220, -4, 304, 868, 266,
789 116, 400, 316, 104, -4, -4, -4, -4, -4, -4, -4, -4,
790 -4, -4, -4, -4, -4, -4, -4, -4, -208, -288, -856, -290,
791 -862, -202, -598, -132, 132, -140, 700, -436, 1000, -144, 532, -260,
792 -4, 712, 268, 422, 860, 450, 276, 124, -4, -4, -4, -4,
793 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
794 -541, -411, -1225, -265, -787, -249, -739, -216, 354, -362, 1168, -460,
795 974, -70, 552, -428, -4, 859, 712, 323, 908, 665, 128, 208,
796 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
797 -4, -4, -4, -4, -106, -52, -148, -66, -190, -79, -229, -31,
798 64, -72, 160, -32, 148, -100, 242, -58, -4, 72, 132, 154,
799 52, 125, 188, 23, -4, -4, -4, -4, -4, -4, -4, -4,
800 -4, -4, -4, -4, -4, -4, -4, -4, -694, -257, -763, -229,
801 -679, -319, -949, -117, 456, -464, 962, -50, 492, -408, 1030, -230,
802 -4, 295, 916, 625, 88, 537, 804, 109, -4, -4, -4, -4,
803 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
804 -244, -140, -412, -182, -538, -238, -706, -116, 156, -164, 428, -116,
805 464, -248, 708, -228, -4, 244, 316, 418, 220, 454, 484, 108,
806 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
809 float expected_21_valid
[] = {
810 -13, -31, -19, -49, -191, -565, -194, -574, -13, -31, -4, 14,
811 -22, 44, -16, 382, -366, 738, -22, 32, 23, -4, 23, 32,
812 545, 20, 204, 720, 5, 32, -4, -4, -4, -4, -4, -4,
813 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
814 -4, -4, -222, -658, -252, -748, -114, -334, -192, -568, -112, -328,
815 -4, 432, -440, 928, -64, 276, -164, 532, -220, 428, 650, -4,
816 304, 868, 266, 116, 400, 316, 104, 428, -4, -4, -4, -4,
817 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
818 -4, -4, -4, -4, -72, -208, -288, -856, -290, -862, -202, -598,
819 -132, -388, -4, 132, -140, 700, -436, 1000, -144, 532, -260, 508,
820 200, -4, 712, 268, 422, 860, 450, 276, 124, 508, -4, -4,
821 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
822 -4, -4, -4, -4, -4, -4, -183, -541, -411, -1225, -265, -787,
823 -249, -739, -216, -640, -4, 354, -362, 1168, -460, 974, -70, 552,
824 -428, 844, 533, -4, 859, 712, 323, 908, 665, 128, 208, 844,
825 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
826 -4, -4, -4, -4, -4, -4, -4, -4, -38, -106, -52, -148,
827 -66, -190, -79, -229, -31, -85, -4, 64, -72, 160, -32, 148,
828 -100, 242, -58, 104, 98, -4, 72, 132, 154, 52, 125, 188,
829 23, 104, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
830 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -234, -694,
831 -257, -763, -229, -679, -319, -949, -117, -343, -4, 456, -464, 962,
832 -50, 492, -408, 1030, -230, 448, 686, -4, 295, 916, 625, 88,
833 537, 804, 109, 448, -4, -4, -4, -4, -4, -4, -4, -4,
834 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
835 -84, -244, -140, -412, -182, -538, -238, -706, -116, -340, -4, 156,
836 -164, 428, -116, 464, -248, 708, -228, 444, 236, -4, 244, 316,
837 418, 220, 454, 484, 108, 444,
840 cnn_config
.layer_config
[0].pad
= PADDING_SAME_ZERO
;
842 RunCNNTest(image_width
, image_height
, input
, expected_21_same
, &cnn_config
,
843 image_width
, &thread_data
, MSE_INT_TOL
);
845 cnn_config
.layer_config
[0].pad
= PADDING_VALID
;
847 RunCNNTest(image_width
, image_height
, input
, expected_21_valid
, &cnn_config
,
848 image_width
, &thread_data
, MSE_INT_TOL
);
851 TEST_F(CNNTest
, TestLargeKernelsAndStrides
) {
852 float input_10x11
[] = {
853 4, 4, 2, 4, 2, -5, -2, 3, -1, 0, 0, 1, 2, 0, -5, -2, -5, 1, -3,
854 -1, 4, -3, 2, -2, 1, 0, 1, -3, -3, -4, -2, -2, 1, -4, -1, 4, 1, -4,
855 -4, -4, 3, 2, -5, 3, -5, 1, 2, -4, 1, -1, 3, 4, -2, 3, -3, 3, 0,
856 2, -4, -5, -5, -2, -1, -2, 1, 1, 1, -2, 4, -5, 4, -1, -1, 2, 3, -4,
857 2, 2, 3, 0, 0, 1, 0, 3, 2, 3, 1, -2, 3, -4, 3, 2, 4, -2, 0,
858 4, -4, 1, -3, -3, -3, -5, 1, -3, -5, 0, 4, -1, -3, 2,
861 float weights_10x11
[] = {
862 -3, 4, -4, -3, -5, 1, -2, 3, 1, -4, -4, 0, -1, 0, 3, 1, -3, -2, 0,
863 -1, 1, 3, -4, -4, -3, -3, -2, 4, 3, -5, 4, 2, -3, 4, -2, -1, 2, -1,
864 -5, 0, -3, 0, 3, -5, -5, 3, -4, -1, -5, 3, 4, 0, 4, -5, 2, -1, 2,
865 -1, -1, -1, -5, 0, -4, 3, -1, 1, 1, -1, 3, 2, -5, -4, 0, -4, 4, -5,
866 -3, 4, -5, 2, -5, -4, -4, -1, 3, 3, 0, 2, -4, 1, -2, 1, 1, 0, 3,
867 -2, 0, 1, 2, 4, -3, -1, -5, -5, 2, -4, 1, 1, 2, -4, -2, -2, 2, 1,
868 3, 4, -5, 1, -1, -3, -3, -1, -2, -5, 1, -1, 0, 1, 4, 4, 0, 0, 4,
869 -3, -1, -5, -3, 0, 1, 1, 1, -5, 3, 4, 3, -5, 3, -2, -2, 0, -4, 0,
870 0, -2, 1, -4, -1, 0, -5, -2, -2, -5, -3, -3, 1, 1, -3, 2, 4, 2, 4,
871 -4, -3, 3, 1, 1, 3, -4, 4, -2, -3, -3, -3, -3, -4, -2, 3, -5, 2, 4,
872 -1, -4, -4, 4, -2, -1, 3, -3, -4, -4, -2, 4, 1, 0, 2, -1, 4, -3, 1,
873 4, -3, 4, 4, 0, -4, 3, -2, -3, 2, 3, -1, -3, 2, 1, 4, -2, -3, 1,
874 4, -2, 2, -2, -5, -2, 1, 4, -1, -4, 4, -5, 2, -5, -4, -1, -2, 3, 1,
875 2, 1, -5, 1, -5, -4, -1, -2, 2, -2, -4, -3, -2, -2, 4, -1, 2, 2, -4,
876 2, -2, 4, -4, -2, -2, 1, -1, 1, 1, 1, -4, -5, -2, 3, -4, -1, 3, -2,
877 3, 2, -5, -4, 0, 3, -2, -4, -5, 3, -2, -4, 2, -2, 1, -4, 0, 2, -5,
878 1, -4, -1, -1, 4, -5, -4, 0, -5, -4, -3, -5, -4, 0, 2, 0, -4, 2, -2,
879 1, 1, -3, 2, 0, -4, 0, -4, 1, 0, -5, -1, -1, -1, -5, 4, 2, 2, -4,
880 3, -2, -2, 2, -3, -2, -1, 2, -4, -5, 2, -2, -4, -5, -5, -1, 2, -1, 0,
881 -5, -2, -2, -5, 0, 1, -1, -5, 0, 3, 2, 3, 0, -3, -2, 0, -5, -1, -2,
882 2, -4, -1, 2, 2, -5, 2, -4, 0, 3, -3, 1, 0, 0, 1, -5, -3, 1, -1,
883 0, -4, -3, 2, -4, -4, 4, -1, 0, 1, 2, -4, -5, 4, -2, 1, -4, -4, -3,
884 -1, -1, 1, -1, -4, -1, -4, -3, 2, -1, -2, -4, 1, 1, 0, -2, 0, -4, 3,
885 -3, 0, -4, -1, -4, 2, -1, -2, -5, -1, -2, -3, 3, -1, 0, -3, 0, 1, -5,
889 float bias_10x11
[] = { 3 };
891 float expected_10x11
[] = {
895 CNN_CONFIG cnn_config
= { 1,
921 int image_height
= 10;
922 int image_width
= 11;
924 CNN_THREAD_DATA thread_data
= { 1, nullptr };
926 RunCNNTest(image_width
, image_height
, input_10x11
, expected_10x11
,
927 &cnn_config
, image_width
, &thread_data
, MSE_INT_TOL
);
929 float input_11x10
[] = {
930 -2, -2, 3, -5, -1, -3, 1, 3, 2, 1, 1, -5, 4, 1, 3, -5, 3, -3, -5,
931 0, -1, -3, -3, 1, 1, -5, -1, -5, -5, -3, 0, 1, -3, -1, -3, -3, 0, 3,
932 4, -4, -1, 3, -3, -1, -3, 1, -3, -2, -1, -4, -3, 2, -4, 1, -4, -1, -3,
933 -5, -1, 2, 3, 0, 2, 2, -5, 4, 1, 2, -1, -4, 4, -4, -4, 0, -1, 1,
934 -1, 1, -3, -3, -2, 1, 2, 4, 4, 4, -3, -3, 0, 1, 0, 1, 4, 1, 3,
935 4, -3, -2, -4, 4, 2, 0, 3, 4, -1, 2, -2, 1, -3, -2,
938 float weights_11x10
[] = {
939 4, -1, 1, -1, 2, 4, 3, 3, -4, 3, -5, 1, -1, -1, -2, -2, 0, 2, -3,
940 -2, 3, -5, -1, 0, -1, -2, -2, -1, 2, 4, 3, 1, 0, 0, -3, 3, -4, -1,
941 -5, 4, -2, -2, 1, 2, -1, -3, 1, 2, -5, 1, -3, 3, 3, 0, -4, -4, -5,
942 -3, -4, -4, 4, -2, 4, 4, -2, 2, -5, -1, -2, -5, -1, 4, -3, 3, -2, 0,
943 -4, -3, 0, -1, -2, 4, 2, 0, -2, -5, -4, 1, 4, -4, -2, 2, -2, 1, 1,
944 -4, 1, -4, -4, -2, 4, 2, -1, -5, -5, 1, -3, -3, 3, -3, -5, -3, 4, -1,
945 -1, -3, 0, -4, 3, -1, 0, -2, 0, -5, -2, -5, 2, 0, -5, 2, 3, -2, 2,
946 4, -1, 1, -3, 2, 3, 2, 0, -5, -4, -5, 2, 1, 1, -1, -2, 3, 4, 2,
947 -2, 4, -2, 3, 1, -4, -3, -1, 4, 4, -3, -5, -2, 2, 0, 3, -2, 3, -1,
948 -4, 0, -2, 0, 3, 4, -2, -3, -2, 0, 3, 4, 2, -4, 0, 1, 2, 2, -1,
949 -1, 4, 1, 4, -2, -1, -1, -5, 1, -3, 3, 3, -1, -4, 3, -5, 0, 0, -1,
950 -4, -1, -2, 4, -2, 3, 3, -3, 1, -1, 2, -1, 4, 4, -2, -2, 4, -2, 0,
951 3, -3, -5, -1, -2, 4, -4, 2, -4, 0, -2, 3, -3, 2, 2, -2, -5, -1, 4,
952 3, -2, -1, 3, 3, -1, 3, 0, -3, 0, 4, 2, 0, -1, 4, 1, 1, 2, 1,
953 3, 1, 1, 1, -3, -5, -4, 4, -4, 2, 0, 0, -4, 1, 4, -5, 4, 4, 0,
954 1, 0, -2, -4, -4, -3, 0, 1, -5, 4, 0, -3, -2, -4, 2, 4, 1, -5, 1,
955 -4, 1, 0, -3, -3, 0, 2, -5, 4, 3, -2, -5, 3, 1, -1, 0, 3, -2, -2,
956 3, -2, -5, 4, 1, -2, 2, -1, 0, 4, 0, -5, 3, -2, 1, 2, 1, -5, -3,
957 -2, -5, 4, -4, 0, 3, 2, -1, -4, -1, 2, 1, -2, 3, -1, -4, 2, 0, -3,
958 1, -1, 2, -5, -4, -1, -5, 1, 4, 3, 4, 2, -3, 1, -5, -1, 3, 0, -1,
959 -4, 3, 4, -5, 4, 4, -3, 2, -3, -1, -3, -5, -3, 2, -3, -2, 1, 1, 0,
960 -5, 3, 2, 1, -5, 1, 1, 1, 3, 4, -4, -1, -2, 0, -5, -3, -5, -2, -4,
961 3, 3, 3, 4, 0, -4, -1, -5, 0, -3, 1, 4, 4, -4, 4, -5, -5, -1, -2,
962 -5, 3, -4, 4, 3, 0, -3, 2, -2, 0, 0, 4, 4, 0, -2, 1, -1, -3, 2,
966 float bias_11x10
[] = {
970 float expected_11x10
[] = {
971 36, -84, 95, 45, 18, 46, 77, -54, -99, -149, 66, 49, 161, 11,
972 39, 61, -66, 61, 4, -3, 34, -44, -23, 31, 64, 29, 47, 72,
973 -27, -27, 121, -3, 100, 1, 30, -78, -12, -89, -59, 8, -16, 112,
974 91, -102, -26, -4, 30, 54, 4, -84, -24, -58, 27, -53, -33, 5,
975 53, -26, 63, 50, -103, -130, -23, 6, -104, -207, 73, 23, 77, 132,
976 38, 32, -130, -44, -60, 7, 27, 176, 45, -32, -2, 99, -97, 63,
977 69, 126, 47, 63, 136, -57, 5, 16, -40, -157, 8, 38, -44, -10,
978 91, 7, 122, 140, 30, -105, 4, -1, 113, 64, 180, 141,
981 cnn_config
.layer_config
[0].weights
= weights_11x10
;
982 cnn_config
.layer_config
[0].bias
= bias_11x10
;
983 cnn_config
.layer_config
[0].filter_width
= 20;
984 cnn_config
.layer_config
[0].filter_height
= 23;
985 cnn_config
.layer_config
[0].skip_width
= 1;
986 cnn_config
.layer_config
[0].skip_height
= 1;
990 RunCNNTest(image_width
, image_height
, input_11x10
, expected_11x10
,
991 &cnn_config
, image_width
, &thread_data
, MSE_INT_TOL
);
994 TEST_F(CNNTest
, TestSoftsignSingleLayer
) {
996 int image_height
= 8;
997 int filter_height
= 5;
998 int filter_width
= 4;
1000 -0.5220f
, 0.8410f
, -0.8990f
, -0.0090f
, 0.6710f
, -0.9470f
, -0.8240f
,
1001 -0.0870f
, 0.5380f
, 0.4750f
, 0.570f
, -0.3760f
, -0.6960f
, -0.5940f
,
1002 -0.3830f
, 0.080f
, -0.0980f
, -0.4940f
, -0.4030f
, 0.9460f
, -0.6020f
,
1003 0.4220f
, 0.6190f
, 0.6640f
, -0.9210f
, -0.1470f
, -0.2480f
, -0.1120f
,
1004 -0.580f
, -0.0650f
, 0.3330f
, 0.9860f
, -0.7430f
, 0.7610f
, 0.4840f
,
1005 0.1030f
, 0.9570f
, 0.6120f
, -0.5240f
, -0.1220f
, -0.5850f
, -0.270f
,
1006 0.7840f
, -0.9790f
, 0.7290f
, -0.30f
, -0.6460f
, 0.0780f
, 0.4750f
,
1007 -0.0510f
, 0.4550f
, 0.3850f
, -0.7230f
, 0.4460f
, -0.6260f
, -0.810f
,
1008 0.8720f
, -0.2120f
, -0.580f
, -0.9510f
, -0.8430f
, -0.1340f
, -0.0850f
,
1011 float expected_same
[] = {
1012 0.430f
, 0.660f
, 0.5510f
, -0.610f
, 0.450f
, -0.1610f
, 0.0520f
, 0.3240f
,
1013 0.6820f
, 0.3820f
, 0.6360f
, 0.7480f
, 0.3080f
, 0.090f
, 0.3910f
, 0.1730f
,
1014 0.340f
, 0.6660f
, -0.4990f
, 0.4280f
, 0.1540f
, 0.120f
, 0.4670f
, 0.6150f
,
1015 -0.3880f
, 0.7590f
, 0.4190f
, 0.7350f
, 0.5310f
, -0.5160f
, -0.1760f
, 0.6790f
,
1016 -0.6780f
, 0.5470f
, 0.5750f
, -0.6420f
, 0.7210f
, -0.4620f
, 0.5430f
, 0.770f
,
1017 -0.1990f
, 0.3950f
, 0.7860f
, -0.4380f
, 0.7540f
, 0.2640f
, -0.6430f
, 0.4510f
,
1018 -0.1260f
, 0.1590f
, -0.2110f
, -0.0560f
, 0.6570f
, 0.680f
, 0.5870f
, 0.4720f
,
1019 0.4040f
, 0.3630f
, 0.670f
, 0.2360f
, 0.410f
, 0.6980f
, -0.5350f
, 0.3940f
,
1021 float expected_replicate
[] = {
1022 0.540f
, 0.7230f
, -0.3530f
, -0.2130f
, 0.7440f
, -0.4470f
, -0.6260f
,
1023 -0.2050f
, 0.7230f
, 0.4630f
, 0.5920f
, 0.7440f
, 0.6080f
, 0.3130f
,
1024 -0.5670f
, -0.4720f
, 0.5480f
, 0.6660f
, -0.4990f
, 0.4280f
, 0.1540f
,
1025 0.120f
, 0.3390f
, 0.6090f
, 0.4160f
, 0.7590f
, 0.4190f
, 0.7350f
,
1026 0.5310f
, -0.5160f
, -0.490f
, 0.4450f
, -0.610f
, 0.5470f
, 0.5750f
,
1027 -0.6420f
, 0.7210f
, -0.4620f
, 0.3150f
, 0.7370f
, -0.5820f
, 0.3950f
,
1028 0.7860f
, -0.4380f
, 0.7540f
, 0.2640f
, -0.7430f
, -0.5340f
, -0.6270f
,
1029 0.4430f
, 0.4730f
, 0.4570f
, 0.7450f
, 0.630f
, 0.2620f
, 0.3140f
,
1030 -0.1840f
, 0.1810f
, 0.7210f
, 0.2760f
, 0.6430f
, 0.6720f
, -0.4390f
,
1033 float expected_valid
[] = {
1034 0.6660f
, -0.4990f
, 0.4280f
, 0.1540f
, 0.120f
, 0.7590f
, 0.4190f
,
1035 0.7350f
, 0.5310f
, -0.5160f
, 0.5470f
, 0.5750f
, -0.6420f
, 0.7210f
,
1036 -0.4620f
, 0.3950f
, 0.7860f
, -0.4380f
, 0.7540f
, 0.2640f
,
1039 0.6210f
, 0.3710f
, -0.2770f
, -0.7230f
, -0.2450f
, 0.6770f
, 0.3080f
,
1040 -0.9880f
, -0.080f
, 0.7190f
, -0.6760f
, -0.0170f
, -0.8970f
, 0.8260f
,
1041 0.7390f
, -0.4550f
, -0.4260f
, -0.6330f
, 0.0880f
, -0.9390f
,
1047 CNN_CONFIG cnn_config
= { 1,
1073 CNN_THREAD_DATA thread_data
= { 1, nullptr };
1075 RunCNNTest(image_width
, image_height
, input
, expected_same
, &cnn_config
,
1076 image_width
, &thread_data
, MSE_FLOAT_TOL
);
1078 cnn_config
.layer_config
[0].pad
= PADDING_SAME_REPLICATE
;
1080 RunCNNTest(image_width
, image_height
, input
, expected_replicate
, &cnn_config
,
1081 image_width
, &thread_data
, MSE_FLOAT_TOL
);
1083 cnn_config
.layer_config
[0].pad
= PADDING_VALID
;
1085 RunCNNTest(image_width
, image_height
, input
, expected_valid
, &cnn_config
,
1086 image_width
, &thread_data
, MSE_FLOAT_TOL
);
1089 TEST_F(CNNTest
, TestBranchTensorAdd
) {
1090 int filter_width
= 2;
1091 int filter_height
= 3;
1093 int image_width
= 4;
1094 int image_height
= 4;
1097 -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1101 -3, -1, 4, -1, -3, 3, 3, 0, 2, 0, 3, 2, 4, 4, 4, -5, 1, -4,
1102 2, -4, 1, -3, 0, 4, -5, 4, 0, -4, -3, -1, 0, 0, -2, 0, 0, 2,
1103 -5, -1, 1, -3, 3, 4, 3, 0, 1, -1, 1, 1, 2, 4, -2, -5, 2, -2,
1104 3, -2, 4, -1, 0, 2, 3, 2, -2, -1, -3, 1, 3, 4, -1, -3, 0, -4,
1105 4, 2, -3, -3, -1, 0, 1, 0, 3, 3, -3, 0, 3, 2, -5, -3, 4, -5,
1106 3, -1, -1, -3, 0, 1, -1, -4, 2, 4, -1, 4, -1, 1, 3, 4, 4, 4,
1107 0, -1, -3, -3, -3, -3, 2, -3, -2, 2, 3, -3,
1111 3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
1114 float expected
[] = {
1115 -11502, -4101, -3424, 668, -17950, -5470, -5504, 626,
1116 4835, 446, 1779, -3483, 3679, -4214, 4578, -105,
1121 CNN_CONFIG cnn_config
= { 6,
1255 // Weights and biases need to be specified separately because
1257 AssignLayerWeightsBiases(&cnn_config
, weights
, bias
);
1259 CNN_THREAD_DATA thread_data
= { 1, nullptr };
1261 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
1262 image_width
, &thread_data
, MSE_INT_TOL
);
1265 TEST_F(CNNTest
, TestBranchTensorConcatenation
) {
1266 int filter_width
= 2;
1267 int filter_height
= 3;
1269 int image_width
= 4;
1270 int image_height
= 4;
1273 -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1277 3, 0, 2, 0, 2, 3, 1, -3, 1, -5, -3, 0, -4, 4, 0, -5, 0, -5, -1,
1278 -2, -5, 0, -3, 2, -4, 2, 0, 2, -1, 0, -4, 3, 0, 0, -1, -5, 2, -1,
1279 4, -4, -2, -3, -3, 3, 4, -2, -1, -4, -1, 4, 4, -1, 4, 3, -4, 2, -2,
1280 -4, -3, -2, 3, -3, -5, -1, 3, -2, 4, 1, -4, -3, -5, -5, -3, 4, -2, -2,
1281 -1, -5, -5, 0, -1, -2, -3, 3, -4, -5, 2, -3, 1, 0, -5, 2, 2, -2, 0,
1282 2, 2, -2, 4, 2, 2, 0, 1, -5, -3, 0, 2, -2, 1, 2, -5, 2, 3, 3,
1283 -1, 3, 0, -3, 3, -4, -4, 3, 3, -4, -2, 2, -2, 2, -2, -1, 3, 0,
1287 -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
1290 float expected
[] = {
1291 -33533, -32087, -6741, -2124, 39979, 41453, 14034, 689,
1292 -22611, -42203, -14882, -239, 15781, 15963, 9524, 837,
1297 CNN_CONFIG cnn_config
= { 6,
1411 channels
+ channels
,
1431 // Weights and biases need to be specified separately because
1433 AssignLayerWeightsBiases(&cnn_config
, weights
, bias
);
1435 CNN_THREAD_DATA thread_data
= { 1, nullptr };
1437 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
1438 image_width
, &thread_data
, MSE_INT_TOL
);
1441 // TODO(logangw): Add test to test all combinations of branch_copy_type.
1443 TEST_F(CNNTest
, TestBranchCombinations
) {
1444 int filter_width
= 2;
1445 int filter_height
= 3;
1447 int image_width
= 4;
1448 int image_height
= 4;
1451 3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
1455 2, 3, 0, 4, 4, 3, 1, 0, 1, -5, 4, -3, 3, 0, 4, -1, -1, -5,
1456 2, 1, -3, -5, 3, -1, -3, -2, 0, -2, 3, 0, -2, -4, -2, -2, 2, -5,
1457 4, -5, 0, 1, -5, -4, -3, -4, 2, -2, 1, 0, 3, -2, -4, 3, 4, -4,
1458 -1, -1, -3, -2, -2, -1, 2, 0, 2, -1, 2, -4, -4, -1, 2, 0, 3, -2,
1459 -2, 3, -3, 4, -2, 4, 3, 4, 1, 0, -2, -3, -5, 1, -3, 2, 0, -2,
1460 -2, -1, -1, -5, -2, -3, -1, 3, 3, 4, 4, 0, 2, 1, 3, -3, 2, -5,
1461 -5, 1, -5, -1, 3, 3, 2, -4, -1, 3, -4, -2, -5, -2, 1, 3, 2, 2,
1462 -5, -2, -3, -1, -2, -4, -1, -2, 2, 1, -4, -4, 2, 0, 2, 0, 2, -3,
1463 -2, -4, 4, 0, 1, -3, -5, 4, -1, 2, 3, -5, -1, 0, 4, -1, -1, 3,
1464 -1, -3, 3, 1, 4, 3, 4, 3, -4, -5, -1, 3, 3, -4, 3, 1, 3, -5,
1465 3, 4, -5, 4, 2, -1, -5, 2, 1, 0, 4, 0, -3, 2, 0, 2, -2, 1,
1466 -1, -2, -1, -5, 4, 3, 3, -2, 2, 4, -5, -5, -3, -2, 4, 0, -4, 1,
1470 -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
1473 float expected
[] = {
1474 149496, 15553, -24193, -20956, 134094, 86432, -68283, -6366,
1475 -53031, 133739, 67407, -13539, -53205, -58635, -20033, 1979,
1480 CNN_CONFIG cnn_config
= { 10,
1708 // Weights and biases need to be specified separately because
1710 AssignLayerWeightsBiases(&cnn_config
, weights
, bias
);
1712 CNN_THREAD_DATA thread_data
= { 1, nullptr };
1714 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
1715 image_width
, &thread_data
, MSE_INT_TOL
);
1718 TEST_F(CNNTest
, TestSplittingTensors
) {
1719 int filter_width
= 2;
1720 int filter_height
= 3;
1722 int image_width
= 4;
1723 int image_height
= 4;
1726 -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
1730 -4, 1, 0, 2, 3, 4, 4, -4, -5, -3, 2, 2, -4, -3, 3, 2,
1731 4, -4, -3, -4, -4, 1, -3, -5, -3, 4, 2, -2, 2, -1, -4, -1,
1732 -2, -3, 1, 1, 0, -5, -1, 3, 3, -5, -3, 0, -3, 1, -3, -1,
1733 1, -3, -2, -2, 4, -2, 0, 1, 2, 2, -4, 2, 4, 0, -5, -2,
1734 4, 4, -5, 1, 0, 2, -2, -5, -5, -3, -5, -5, 4, -3, 0, 0,
1735 -4, -4, 0, -5, -4, 0, 0, -3, -5, -3, -1, 2, -1, 4, -1, 2,
1739 -4, -2, -3, -3, 3, 1, -2,
1742 float expected
[] = {
1743 530, -762, 1469, 777, 849, -771, -1698, 600,
1744 -658, -1821, 98, -668, -1798, 30, 887, -971,
1747 CNN_CONFIG cnn_config
= { 3,
1823 // Weights and biases need to be specified separately because
1825 AssignLayerWeightsBiases(&cnn_config
, weights
, bias
);
1827 CNN_THREAD_DATA thread_data
= { 1, nullptr };
1829 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
1830 image_width
, &thread_data
, MSE_INT_TOL
);
1833 TEST_F(CNNTest
, TestOutputChannelsCount
) {
1834 int filter_width
= 1;
1835 int filter_height
= 1;
1837 int image_width
= 2;
1838 int image_height
= 2;
1840 float input
[] = { 0, 0, 0, 0 };
1842 float weights
[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
1844 float bias
[] = { 0, 0, 0, 0, 0, 0 };
1846 float expected
[] = {
1847 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1848 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1851 CNN_CONFIG cnn_config
= { 3,
1931 // Weights and biases need to be specified separately because
1933 AssignLayerWeightsBiases(&cnn_config
, weights
, bias
);
1935 CNN_THREAD_DATA thread_data
= { 1, nullptr };
1937 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
1938 image_width
, &thread_data
, MSE_FLOAT_TOL
);
1941 TEST_F(CNNTest
, TestBatchNorm
) {
1942 int image_width
= 28;
1943 int image_height
= 28;
1944 int filter_height
= 7;
1945 int filter_width
= 7;
1947 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1948 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1949 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1950 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1951 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1952 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1953 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1954 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1955 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1956 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1957 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1958 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1959 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1960 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1961 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1962 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1963 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1964 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1965 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1966 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1967 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1968 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1969 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1970 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1971 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1972 0.0f
, 0.0f
, 0.0117647f
, 0.0705882f
, 0.0705882f
, 0.0705882f
,
1973 0.494118f
, 0.533333f
, 0.686275f
, 0.101961f
, 0.65098f
, 1.0f
,
1974 0.968627f
, 0.498039f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1975 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1976 0.0f
, 0.0f
, 0.117647f
, 0.141176f
, 0.368627f
, 0.603922f
,
1977 0.666667f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.992157f
,
1978 0.882353f
, 0.67451f
, 0.992157f
, 0.94902f
, 0.764706f
, 0.25098f
,
1979 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1980 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.192157f
,
1981 0.933333f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.992157f
,
1982 0.992157f
, 0.992157f
, 0.992157f
, 0.984314f
, 0.364706f
, 0.321569f
,
1983 0.321569f
, 0.219608f
, 0.152941f
, 0.0f
, 0.0f
, 0.0f
,
1984 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1985 0.0f
, 0.0f
, 0.0f
, 0.0705882f
, 0.858824f
, 0.992157f
,
1986 0.992157f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.776471f
, 0.713725f
,
1987 0.968627f
, 0.945098f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1988 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1989 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1990 0.0f
, 0.0f
, 0.313725f
, 0.611765f
, 0.419608f
, 0.992157f
,
1991 0.992157f
, 0.803922f
, 0.0431373f
, 0.0f
, 0.168627f
, 0.603922f
,
1992 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1993 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1994 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1995 0.0f
, 0.054902f
, 0.00392157f
, 0.603922f
, 0.992157f
, 0.352941f
,
1996 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1997 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1998 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
1999 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2000 0.0f
, 0.545098f
, 0.992157f
, 0.745098f
, 0.00784314f
, 0.0f
,
2001 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2002 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2003 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2004 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0431373f
,
2005 0.745098f
, 0.992157f
, 0.27451f
, 0.0f
, 0.0f
, 0.0f
,
2006 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2007 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2008 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2009 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.137255f
, 0.945098f
,
2010 0.882353f
, 0.627451f
, 0.423529f
, 0.00392157f
, 0.0f
, 0.0f
,
2011 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2012 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2013 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2014 0.0f
, 0.0f
, 0.0f
, 0.317647f
, 0.941176f
, 0.992157f
,
2015 0.992157f
, 0.466667f
, 0.0980392f
, 0.0f
, 0.0f
, 0.0f
,
2016 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2017 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2018 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2019 0.0f
, 0.0f
, 0.176471f
, 0.729412f
, 0.992157f
, 0.992157f
,
2020 0.588235f
, 0.105882f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2021 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2022 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2023 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2024 0.0f
, 0.0627451f
, 0.364706f
, 0.988235f
, 0.992157f
, 0.733333f
,
2025 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2026 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2027 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2028 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2029 0.0f
, 0.976471f
, 0.992157f
, 0.976471f
, 0.25098f
, 0.0f
,
2030 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2031 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2032 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2033 0.0f
, 0.0f
, 0.180392f
, 0.509804f
, 0.717647f
, 0.992157f
,
2034 0.992157f
, 0.811765f
, 0.00784314f
, 0.0f
, 0.0f
, 0.0f
,
2035 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2036 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2037 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.152941f
, 0.580392f
,
2038 0.898039f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.980392f
, 0.713725f
,
2039 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2040 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2041 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2042 0.0941176f
, 0.447059f
, 0.866667f
, 0.992157f
, 0.992157f
, 0.992157f
,
2043 0.992157f
, 0.788235f
, 0.305882f
, 0.0f
, 0.0f
, 0.0f
,
2044 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2045 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2046 0.0f
, 0.0f
, 0.0901961f
, 0.258824f
, 0.835294f
, 0.992157f
,
2047 0.992157f
, 0.992157f
, 0.992157f
, 0.776471f
, 0.317647f
, 0.00784314f
,
2048 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2049 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2050 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0705882f
, 0.670588f
,
2051 0.858824f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.992157f
, 0.764706f
,
2052 0.313725f
, 0.0352941f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2053 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2054 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2055 0.215686f
, 0.67451f
, 0.886275f
, 0.992157f
, 0.992157f
, 0.992157f
,
2056 0.992157f
, 0.956863f
, 0.521569f
, 0.0431373f
, 0.0f
, 0.0f
,
2057 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2058 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2059 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.533333f
, 0.992157f
,
2060 0.992157f
, 0.992157f
, 0.831373f
, 0.529412f
, 0.517647f
, 0.0627451f
,
2061 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2062 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2063 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2064 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2065 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2066 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2067 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2068 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2069 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2070 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2071 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2072 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2073 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2074 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2075 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2076 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.0f
,
2077 0.0f
, 0.0f
, 0.0f
, 0.0f
2079 float expected
[] = {
2080 -0.836424f
, -0.857365f
, -1.62739f
, -1.62739f
, -0.836424f
, 5.40742f
,
2081 0.920853f
, -0.692567f
, -0.836424f
, -0.534405f
, -1.62739f
, -0.836424f
,
2082 1.32602f
, 1.36312f
, 0.112766f
, -0.836424f
, -0.192962f
, 1.56975f
,
2083 2.45777f
, 0.944414f
, -0.192962f
, -1.5519f
, -1.5519f
, -0.554006f
,
2084 -0.192962f
, 1.4231f
, -1.5519f
, -0.192962f
, 1.3661f
, -1.5519f
,
2085 -1.5519f
, -0.192962f
, -0.843708f
, -0.359025f
, -0.843708f
, -0.843708f
,
2086 -0.843708f
, 4.53065f
, 0.0429584f
, -0.796804f
, -0.843708f
, 0.3473f
,
2087 -0.843708f
, -0.843708f
, -0.114439f
, 3.14817f
, 0.0811934f
, -0.843708f
2090 0.119643f
, -0.237864f
, 0.0462892f
, 0.0502297f
, -0.0134528f
,
2091 0.146347f
, 0.153133f
, 0.0513307f
, 0.0752369f
, 0.0135557f
,
2092 -0.111434f
, 0.0941854f
, 0.0788362f
, 0.0299412f
, 0.111762f
,
2093 0.144066f
, 0.00431504f
, -0.0177954f
, 0.0738092f
, -0.0344215f
,
2094 0.0832582f
, 0.053989f
, -0.112691f
, 0.0962145f
, 0.0186525f
,
2095 -0.00660205f
, -0.111962f
, -0.126801f
, -0.231625f
, 0.17309f
,
2096 0.0748875f
, -0.179569f
, -0.00513812f
, -0.156579f
, -0.147322f
,
2097 0.184168f
, 0.189308f
, -0.200359f
, -0.0156733f
, 0.140649f
,
2098 0.0858496f
, -0.0263217f
, -0.0740749f
, -0.112563f
, 0.107528f
,
2099 0.0609729f
, -0.221625f
, 0.0769944f
, -0.00900815f
, -0.00136441f
,
2100 -0.0236521f
, -0.0418025f
, -0.00286299f
, 0.12241f
, 0.0964093f
,
2101 -0.0150897f
, 0.0532171f
, 0.0625916f
, 0.116939f
, 0.118024f
,
2102 0.161918f
, -0.00909767f
, 0.100897f
, -0.054563f
, -0.175179f
,
2103 -0.0687892f
, 0.00734235f
, 0.109833f
, -0.113776f
, 0.0595405f
,
2104 -0.170255f
, 0.0124815f
, -0.0363301f
, -0.0127038f
, 0.0445554f
,
2105 -0.0729894f
, 0.107428f
, -0.0341417f
, 0.132619f
, 0.00984557f
,
2106 -0.00443654f
, 0.202929f
, 0.0945134f
, 0.0148725f
, 0.00998574f
,
2107 -0.0226449f
, 0.0478197f
, -0.0793442f
, 0.0707599f
, -0.084225f
,
2108 0.0865795f
, 0.071104f
, -0.047894f
, 0.0838322f
, 0.0635493f
,
2109 -0.00370265f
, -0.157247f
, -0.0289622f
, -0.0590963f
, 0.13207f
,
2110 0.00468011f
, -0.0345372f
, 0.217939f
, 0.18861f
, -0.0290393f
,
2111 -0.0440664f
, 0.0126197f
, -0.129132f
, -0.124943f
, 0.0968156f
,
2112 -0.0853643f
, -0.182305f
, 0.00461618f
, -0.147095f
, -0.230282f
,
2113 0.00856019f
, 0.0278893f
, -0.0300229f
, 0.0417871f
, 0.0804717f
,
2114 -0.0768571f
, -0.0397085f
, -0.0601096f
, 0.100901f
, -0.0184926f
,
2115 0.0350673f
, 0.0971094f
, -0.0171837f
, -0.289644f
, -0.0899041f
,
2116 0.08998f
, -0.160319f
, -0.0195103f
, 0.0392167f
, -0.137864f
,
2117 -0.0136294f
, 0.0330886f
, -0.0409244f
, -0.092533f
, -0.0427934f
,
2118 -0.191144f
, -0.0969461f
, 0.112035f
, 0.138611f
, 0.128717f
,
2119 0.191184f
, 0.197462f
2121 float bias
[] = { 0.186703f
, 0.204358f
, -0.0230452f
};
2123 float bn_gamma
[] = { 1.32173f
, 1.26171f
, 1.21966f
};
2124 float bn_beta
[] = { -0.232595f
, -0.222652f
, -0.232209f
};
2125 float bn_mean
[] = { 0.329233f
, 0.199894f
, 0.12389f
};
2126 float bn_std
[] = { 0.311986f
, 0.189737f
, 0.247104f
};
2128 CNN_BATCHNORM_PARAMS bn_params
= {
2135 CNN_CONFIG cnn_config
= {
2165 CNN_THREAD_DATA thread_data
= { 1, nullptr };
2167 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
2168 image_width
, &thread_data
, MSE_FLOAT_TOL
);
2171 TEST_F(CNNTest
, TestMultithreading
) {
2172 int image_height
= 2;
2173 int image_width
= 2;
2174 int filter_height
= 3;
2175 int filter_width
= 3;
2185 -4, 2, -2, 0, -4, 4, -3, -3, -3, -1, 1, 0, -5, -3, 0, -5, 0, 0,
2186 -1, 0, 2, -5, 0, 1, 4, 2, 1, 0, -2, -1, -5, -3, 2, -2, 1, -5,
2196 float expected
[] = {
2197 2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
2200 CNN_CONFIG cnn_config
= {
2230 CNN_THREAD_DATA thread_data
= { 1, nullptr };
2232 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
2233 image_width
, &thread_data
, MSE_FLOAT_TOL
);
2235 const AVxWorkerInterface
*const winterface
= aom_get_worker_interface();
2236 AVxWorker workers
[4];
2238 for (int i
= 0; i
< 4; ++i
) {
2239 winterface
->init(&workers
[i
]);
2242 thread_data
= { 4, workers
};
2244 RunCNNTest(image_width
, image_height
, input
, expected
, &cnn_config
,
2245 image_width
, &thread_data
, MSE_FLOAT_TOL
);
2247 for (int i
= 0; i
< 4; ++i
) {
2248 winterface
->end(&workers
[i
]);
2252 TEST_F(CNNTest
, TestMultiOutput
) {
2253 const int image_dim
= 8;
2254 const int image_ch
= 3;
2255 const int filter_dim
= 2;
2256 const int stride
= 2;
2257 const int num_filters
= 2;
2259 const float input_
[] = {
2260 1.7537929121f
, 0.134331551012f
, 0.123580039877f
, 0.957731845246f
,
2261 0.391006834217f
, 1.00699352042f
, -0.778177955829f
, -0.814166433059f
,
2262 -0.656374394915f
, 0.321967305228f
, -2.19455719176f
, 0.708035038966f
,
2263 0.409148822266f
, -0.318254408902f
, 0.152450211189f
, -0.250210793369f
,
2264 0.826811563186f
, 1.6804156584f
, 0.273626975978f
, 0.437936241887f
,
2265 -0.329935520167f
, -0.288761611645f
, 0.156937008304f
, 0.271054157295f
,
2266 -0.0224828854332f
, 1.70110336895f
, -0.989066699309f
, 1.30863131729f
,
2267 -0.165813705702f
, 0.00380178619265f
, -0.0837342367587f
, 0.760954783156f
,
2268 -0.413610373524f
, 1.17968204175f
, 0.720295719536f
, 0.308718974472f
,
2269 -1.10091337671f
, 0.693160033687f
, -0.0202862320697f
, 1.0221927503f
,
2270 -1.24521801881f
, -0.478501952308f
, -1.71648619442f
, -0.182571723636f
,
2271 0.339292649504f
, 2.0806519131f
, 0.967974033444f
, 0.175248672328f
,
2272 0.0658124561472f
, 0.795504169496f
, 0.750592557361f
, -1.46631013249f
,
2273 -1.79052846838f
, -1.03672179515f
, -0.841985521653f
, 1.20995011489f
,
2274 0.140859718215f
, -0.651552622661f
, 0.451065110806f
, 1.1189443693f
,
2275 0.100213260593f
, -0.834076868118f
, -1.28734321611f
, 1.22064420095f
,
2276 -0.364143084361f
, 0.750961509335f
, -0.888689074553f
, -0.8253547106f
,
2277 -1.21800999027f
, -0.966670603566f
, 1.37384014741f
, 0.47281264834f
,
2278 -0.420416235531f
, 0.520163906493f
, 0.501296589423f
, 1.53418976951f
,
2279 0.715234751485f
, 0.644551588907f
, 0.0763504863375f
, -0.0018541943723f
,
2280 0.322853189656f
, -0.795099723224f
, -0.125177096675f
, 1.4476577471f
,
2281 -0.585888410088f
, -1.44391754955f
, -0.610543221933f
, -0.221859179799f
,
2282 0.252060200774f
, -0.86287169623f
, -0.0350246229157f
, 1.0932311997f
,
2283 0.899464648842f
, -0.468806951704f
, -0.300861137168f
, 1.15776414206f
,
2284 1.03268544738f
, -0.171579585622f
, -0.179136557119f
, -0.354091003368f
,
2285 -0.612298249394f
, -1.20237379258f
, 1.54604109659f
, 0.130664370287f
,
2286 0.885225111868f
, 1.0362799581f
, 0.980561720868f
, -0.619379186999f
,
2287 -1.33818929924f
, -0.237233737961f
, -1.89335425073f
, 0.567821011321f
,
2288 0.862420368465f
, -1.37380916821f
, 0.352190056666f
, 0.611261516274f
,
2289 0.393237747152f
, 0.894686247967f
, 0.190405182149f
, 0.264872662911f
,
2290 -0.0657009133797f
, 0.0580512653493f
, -0.401825294366f
, 0.4106081318f
,
2291 0.49484512188f
, -0.0751103149442f
, -1.43243736382f
, 1.79855656009f
,
2292 -1.1075351975f
, 0.000354882733011f
, -0.950716438608f
, 1.27129831688f
,
2293 1.00495189838f
, 0.110358656713f
, 1.08315032822f
, -0.972676676218f
,
2294 -0.0757668962831f
, 1.88932045165f
, -0.0672638136275f
, 0.425913010161f
,
2295 -0.781540372017f
, 0.976000248609f
, 0.687218504122f
, 1.31374513445f
,
2296 -0.932658930672f
, -1.25339468479f
, 0.422071294078f
, -0.24189927912f
,
2297 0.216906604642f
, -1.88720997548f
, 1.99252872889f
, 0.353943735777f
,
2298 0.737434784132f
, -1.17848645017f
, 1.70424254896f
, 0.775297112968f
,
2299 -0.516392797501f
, 0.398130609129f
, 0.737248101457f
, 0.166282500886f
,
2300 1.24699015468f
, 0.47116183125f
, 1.19091180182f
, -0.372695424578f
,
2301 0.219773209389f
, -0.829467838962f
, -0.52533122724f
, 1.98707754595f
,
2302 0.553692606972f
, -0.933228902369f
, 1.55427751643f
, -1.08813399144f
,
2303 -0.325686682094f
, 0.205091443796f
, -1.70381666435f
, 0.466465327942f
,
2304 1.73126863447f
, -0.939133672634f
, 1.48318077459f
, -0.599414038168f
,
2305 -1.1583078687f
, 0.518116190201f
, 0.133571482458f
, 0.84958342672f
,
2306 1.02205000597f
, -0.0772082009087f
, -1.69567503859f
, 1.4697939436f
,
2307 1.67813743122f
, -0.627911582938f
, 0.131380509137f
, -1.35717850726f
,
2309 const float *input
[3] = { input_
, &input_
[image_dim
* image_dim
],
2310 &input_
[2 * image_dim
* image_dim
] };
2312 const float bias
[] = { 0.0f
, 0.0f
};
2314 const float weights_1
[] = {
2315 -0.489547413618f
, 0.141916424749f
, -0.279286485585f
, -0.115322211094f
,
2316 0.299572786936f
, 0.205289980785f
, -0.536254480088f
, -0.253626313744f
,
2317 -0.422883815849f
, -0.169702966298f
, -0.540104704793f
, 0.495319646763f
,
2318 0.298799079422f
, -0.10054550901f
, -0.306085047056f
, 0.171061886165f
,
2319 -0.108058703878f
, -0.410734629888f
, -0.0640674673049f
, -0.386524840979f
,
2320 -0.157203423678f
, -0.362138920529f
, -0.216206085209f
, 0.147502517971f
,
2323 const float weights_2
[] = {
2324 0.207580604357f
, 0.480821146263f
, -0.29111909562f
, 0.47422567493f
,
2325 0.206892553253f
, -0.235067084092f
, 0.354516800602f
, -0.212399370252f
,
2326 -0.419071343731f
, -0.050350731631f
, -0.0516457320279f
, -0.0359310500731f
,
2327 0.567044864811f
, -0.060341127522f
, 0.0501464839637f
, -0.437785677916f
,
2330 const float weights_3
[] = {
2331 -0.0690452401448f
, -0.356657338763f
, -0.219464031809f
, 0.551288365843f
,
2332 0.181372090853f
, -0.00245268542109f
, 0.409000696276f
, -0.593209108763f
,
2333 0.587352566749f
, -0.243720660227f
, 0.266232713887f
, -0.00439285245097f
,
2334 0.252883228305f
, 0.152646192631f
, 0.0918944932026f
, 0.398853715057f
,
2337 const float weights_4
[] = {
2338 0.207560791573f
, 0.194201350401f
, 0.227802322443f
, 0.206533663345f
,
2339 0.0557331066805f
, 0.0224159800424f
, -0.143939197467f
, -0.27703361602f
,
2340 0.130643888389f
, -0.269456557461f
, 0.186242862864f
, -0.162879944774f
,
2341 -0.145503996718f
, -0.0768822987581f
, -0.203127976359f
, -0.238119922873f
,
2342 -0.258806479994f
, 0.0357957680385f
, -0.1027606976f
, -0.287920082345f
,
2343 0.189047820993f
, 0.250711538481f
, -0.272815714175f
, -0.0431449742024f
,
2344 0.207261230996f
, -0.0396472677451f
, 0.131236557412f
, 0.174291832499f
,
2345 -0.251515885765f
, -0.107164007499f
, 0.185824534748f
, -0.00561585838161f
,
2346 0.273393799578f
, -0.139563699075f
, -0.263922456031f
, -0.118859844081f
,
2347 0.109230982597f
, -0.170170294794f
, 0.0123025648515f
, -0.0839368964355f
,
2348 -0.0774058234297f
, 0.255847138286f
, -0.208430879637f
, 0.279170114319f
,
2349 -0.272890330712f
, -0.217725903006f
, -0.295923275459f
, -0.17008723953f
,
2350 -0.284281803405f
, 0.281406323629f
, 0.266910044663f
, -0.209963914338f
,
2351 0.271980962964f
, 0.142013581699f
, -0.143896509026f
, -0.290509242975f
,
2352 -0.305768180935f
, 0.196902832117f
, -0.090424189662f
, -0.147460802346f
,
2353 0.217722016651f
, 0.12353848977f
, -0.169177363577f
, -0.0454230918512f
,
2356 const float expected_0
[] = {
2357 -2.04858441055f
, -2.12883075791f
, -0.045177363807f
, 0.763949675768f
,
2358 -0.544361512821f
, -1.58123168032f
, 1.89319847039f
, 0.16859080901f
,
2359 -1.16023321135f
, -0.396988107751f
, 1.76637090744f
, -1.40434786514f
,
2360 0.908227575669f
, 0.817064817605f
, 0.215631134908f
, -0.848605613428f
,
2361 -0.106756747018f
, 0.0193027166685f
, 0.801345615113f
, -0.395407237598f
,
2362 -1.79983795658f
, -1.73054496242f
, 0.0584392594454f
, -0.388786095569f
,
2363 -0.237269619354f
, 0.000843578271263f
, -1.24043512104f
, 0.487839445893f
,
2364 -0.394259726605f
, 0.559632843424f
, -0.527224052291f
, -1.53792340282f
,
2367 const float expected_1
[] = {
2368 0.0f
, 0.0f
, 0.0f
, 0.0f
, 0.4057888292f
, 0.325309571755f
,
2369 0.0f
, 1.22013465602f
,
2372 const float expected_2
[] = {
2377 const float expected_3
[] = {
2384 const float *expected
[] = { expected_0
, expected_1
, expected_2
, expected_3
};
2386 CNN_CONFIG cnn_config
= {
2395 image_ch
, // in_channels
2396 filter_dim
, // filter_width
2397 filter_dim
, // filter_height
2398 num_filters
, // out_channels
2399 stride
, // skip_width
2400 stride
, // skip_height
2402 weights_1
, // weights
2404 PADDING_SAME_ZERO
, // pad
2408 BRANCH_OUTPUT
, // branch_copy_type
2409 BRANCH_NOC
, // branch_combine_type
2410 { 2, 0, 0 }, // branch_config
2415 num_filters
, // in_channels
2416 filter_dim
, // filter_width
2417 filter_dim
, // filter_height
2418 num_filters
, // out_channels
2419 stride
, // skip_width
2420 stride
, // skip_height
2422 weights_2
, // weights
2424 PADDING_SAME_ZERO
, // pad
2428 BRANCH_NO_COPY
, // branch_copy_type
2429 BRANCH_NOC
, // branch_combine_type
2430 {}, // branch_config
2435 num_filters
, // in_channels
2436 filter_dim
, // filter_width
2437 filter_dim
, // filter_height
2438 num_filters
, // out_channels
2439 stride
, // skip_width
2440 stride
, // skip_height
2442 weights_3
, // weights
2444 PADDING_SAME_ZERO
, // pad
2448 BRANCH_NO_COPY
, // branch_copy_type
2449 BRANCH_NOC
, // branch_combine_type
2450 {}, // branch_config
2455 num_filters
, // in_channels
2456 2 * filter_dim
, // filter_width
2457 2 * filter_dim
, // filter_height
2458 num_filters
, // out_channels
2459 2 * stride
, // skip_width
2460 2 * stride
, // skip_height
2462 weights_4
, // weights
2464 PADDING_VALID
, // pad
2468 BRANCH_NO_COPY
, // branch_copy_type
2469 BRANCH_CAT
, // branch_combine_type
2470 { 0, 0, 1 }, // branch_config
2477 CNN_THREAD_DATA thread_data
= { 1, nullptr };
2479 const int num_outputs
= 4;
2480 const int output_chs
[4] = { filter_dim
, filter_dim
, filter_dim
,
2482 const int output_dims
[4] = { 4, 2, 1, 1 };
2483 const int output_sizes
[4] = {
2484 output_chs
[0] * output_dims
[0] * output_dims
[0],
2485 output_chs
[1] * output_dims
[1] * output_dims
[1],
2486 output_chs
[2] * output_dims
[2] * output_dims
[2],
2487 output_chs
[3] * output_dims
[3] * output_dims
[3],
2489 float *const output_
= (float *)aom_malloc(
2491 (output_sizes
[0] + output_sizes
[1] + output_sizes
[2] + output_sizes
[3]));
2492 ASSERT_NE(output_
, nullptr);
2493 float *output
[CNN_MAX_CHANNELS
] = { nullptr };
2495 float *output_ite
= output_
;
2496 for (int output_idx
= 0; output_idx
< num_outputs
; output_idx
++) {
2497 for (int channel
= 0; channel
< output_chs
[output_idx
]; ++channel
) {
2498 output
[ch_ite
++] = output_ite
;
2499 output_ite
+= output_dims
[output_idx
] * output_dims
[output_idx
];
2502 CNN_MULTI_OUT output_struct
= { num_outputs
, output_chs
, output_dims
,
2505 RunMultiOutCNNTest(input
, image_dim
, image_dim
, image_dim
, &cnn_config
,
2506 &thread_data
, &output_struct
, expected
, MSE_FLOAT_TOL
);
2513 typedef void (*CNNConvolveNoMaxpoolPaddingValidFunc
)(
2514 const float **input
, int in_width
, int in_height
, int in_stride
,
2515 const CNN_LAYER_CONFIG
*layer_config
, float **output
, int out_stride
,
2516 int start_idx
, int cstep
, int channel_step
);
2518 typedef libaom_test::FuncParam
<CNNConvolveNoMaxpoolPaddingValidFunc
>
2519 CNNConvolveTestFuncs
;
2521 class CNNConvolveTest
: public ::testing::TestWithParam
<CNNConvolveTestFuncs
> {
2523 void SetUp() override
{ params_
= GetParam(); }
2525 void RunCNNConvolveSetup(int run_times
) {
2529 const CNN_CONFIG
*cnn_config
= &av1_intra_mode_cnn_partition_cnn_config
;
2531 for (int layer
= 0; layer
< cnn_config
->num_layers
; ++layer
) {
2532 int out_width
= 0, out_height
= 0;
2533 int in_size
= in_width
* in_height
;
2534 // Get current layer output width and height.
2535 av1_find_cnn_layer_output_size(in_height
, in_width
,
2536 &cnn_config
->layer_config
[layer
],
2537 &out_width
, &out_height
);
2539 int out_size
= out_width
* out_height
;
2540 float *input
[20], *output_ref
[20], *output_mod
[20];
2543 (float *)aom_malloc(sizeof(*input_data
) * in_size
*
2544 cnn_config
->layer_config
[layer
].in_channels
);
2545 float *temp_ptr
= input_data
;
2546 ASSERT_NE(temp_ptr
, nullptr);
2547 for (int i
= 0; i
< cnn_config
->layer_config
[layer
].in_channels
; ++i
) {
2548 input
[i
] = temp_ptr
;
2549 for (int j
= 0; j
< in_size
; j
++) {
2550 *(temp_ptr
++) = ((float)rng_
.Rand31() - (1 << 30)) / (1u << 31);
2554 float *out_data_ref
= (float *)aom_calloc(
2555 sizeof(*out_data_ref
),
2556 out_size
* cnn_config
->layer_config
[layer
].out_channels
);
2557 ASSERT_NE(out_data_ref
, nullptr);
2558 float *out_data_mod
= (float *)aom_calloc(
2559 sizeof(*out_data_mod
),
2560 out_size
* cnn_config
->layer_config
[layer
].out_channels
);
2561 ASSERT_NE(out_data_mod
, nullptr);
2562 float *temp_ptr1
= out_data_ref
;
2563 float *temp_ptr2
= out_data_mod
;
2564 for (int i
= 0; i
< cnn_config
->layer_config
[layer
].out_channels
; ++i
) {
2565 output_ref
[i
] = temp_ptr1
;
2566 output_mod
[i
] = temp_ptr2
;
2567 temp_ptr1
+= out_size
;
2568 temp_ptr2
+= out_size
;
2571 RunCNNConvolveTest(input
, in_width
, in_height
, out_size
,
2572 &cnn_config
->layer_config
[layer
], 0, 1, run_times
,
2573 layer
, output_ref
, output_mod
, out_width
);
2575 // Set current layer output width and height as next layer input width and
2577 in_width
= out_width
;
2578 in_height
= out_height
;
2580 aom_free(input_data
);
2581 aom_free(out_data_ref
);
2582 aom_free(out_data_mod
);
2586 void RunCNNConvolveTest(float **input
, int in_width
, int in_height
,
2587 int out_size
, const CNN_LAYER_CONFIG
*layer_config
,
2588 int start_idx
, int step
, int run_times
, int layer
,
2589 float **output_ref
, float **output_mod
,
2591 const int cstep
= layer_config
->in_channels
* layer_config
->out_channels
;
2592 const int channel_step
= AOMMAX(step
, 1);
2593 aom_usec_timer timer
;
2594 aom_usec_timer_start(&timer
);
2595 for (int i
= 0; i
< run_times
; ++i
) {
2596 params_
.ref_func((const float **)input
, in_width
, in_height
, in_width
,
2597 layer_config
, output_ref
, out_stride
, start_idx
, cstep
,
2600 aom_usec_timer_mark(&timer
);
2601 const double time1
= static_cast<double>(aom_usec_timer_elapsed(&timer
));
2603 aom_usec_timer_start(&timer
);
2604 for (int i
= 0; i
< run_times
; ++i
) {
2605 params_
.tst_func((const float **)input
, in_width
, in_height
, in_width
,
2606 layer_config
, output_mod
, out_stride
, start_idx
, cstep
,
2609 aom_usec_timer_mark(&timer
);
2610 const double time2
= static_cast<double>(aom_usec_timer_elapsed(&timer
));
2612 if (run_times
> 1) {
2613 printf("layer : %d \n", layer
);
2614 printf("%7.2f/%7.2fns (%3.2f)\n", time1
, time2
, time1
/ time2
);
2616 for (int channel
= 0; channel
< layer_config
->out_channels
; ++channel
) {
2617 const float *buf_ref
= output_ref
[channel
];
2618 const float *buf_mod
= output_mod
[channel
];
2620 for (int i
= 0; i
< out_size
; ++i
) {
2621 if (buf_ref
[i
] < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL
) {
2622 ASSERT_LE(buf_ref
[i
], CNN_CONVOLVE_PIXELWISE_FLOAT_TOL
)
2623 << "Reference output was near-zero, test output was not ("
2624 << buf_mod
[i
] << ")";
2626 const float error
= buf_ref
[i
] - buf_mod
[i
];
2627 const float relative_error
= fabsf(error
/ buf_ref
[i
]);
2628 ASSERT_LE(relative_error
, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL
)
2629 << " channel " << channel
<< " pixel " << i
<< ": "
2630 << buf_ref
[i
] << "/" << buf_mod
[i
] << std::endl
;
2638 CNNConvolveTestFuncs params_
;
2639 libaom_test::ACMRandom rng_
;
2641 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest
);
2643 TEST_P(CNNConvolveTest
, CheckOutput
) { RunCNNConvolveSetup(1); }
2645 TEST_P(CNNConvolveTest
, DISABLED_Speed
) { RunCNNConvolveSetup(100000); }
2647 #if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
2648 INSTANTIATE_TEST_SUITE_P(AVX2
, CNNConvolveTest
,
2649 ::testing::Values(CNNConvolveTestFuncs(
2650 &av1_cnn_convolve_no_maxpool_padding_valid_c
,
2651 &av1_cnn_convolve_no_maxpool_padding_valid_avx2
)));
2655 INSTANTIATE_TEST_SUITE_P(NEON
, CNNConvolveTest
,
2656 ::testing::Values(CNNConvolveTestFuncs(
2657 &av1_cnn_convolve_no_maxpool_padding_valid_c
,
2658 &av1_cnn_convolve_no_maxpool_padding_valid_neon
)));