Bug 1882465 - Update .hg-annotate-ignore-revs and .git-blame-ignore-revs to reflect...
[gecko.git] / third_party / aom / test / cnn_test.cc
blobe5114b56ce4a7f66bcd82c53a9f8d947fe674539
1 /*
2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
12 #include <assert.h>
13 #include <math.h>
14 #include <stdio.h>
16 #include "third_party/googletest/src/googletest/include/gtest/gtest.h"
18 #include "config/av1_rtcd.h"
20 #include "aom_ports/aom_timer.h"
21 #include "av1/encoder/cnn.h"
22 #include "av1/encoder/partition_cnn_weights.h"
23 #include "test/acm_random.h"
24 #include "test/function_equivalence_test.h"
25 #include "test/util.h"
// Squares its argument. The argument expression is expanded twice, so it must
// not have side effects.
27 #define SQR(x) ((x) * (x))
29 // Best possible pixelwise guaranteed precision given each float has at most
30 // 3 specified decimals.
31 #define PIXELWISE_FLOAT_TOL 1E-2
// Mean-squared-error bounds. MSE_INT_TOL is what the tests below pass as the
// `tolerance` argument of CNNTest::RunCNNTest; being 0, it requires the MSE to
// be exactly zero. MSE_FLOAT_TOL is presumably the counterpart for tests whose
// data is not integer valued -- confirm against its users.
33 #define MSE_FLOAT_TOL 1E-6
34 #define MSE_INT_TOL 0
36 // CNN convolve pixelwise error threshold for functional equivalence.
37 #define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
39 namespace {
41 class CNNTest : public ::testing::Test {
42 protected:
43 static void RunCNNTest(int image_width, int image_height, const float *input,
44 const float *expected, const CNN_CONFIG *cnn_config,
45 int in_stride, CNN_THREAD_DATA *thread_data,
46 double tolerance) {
47 int out_width, out_height, out_channels;
48 av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
49 &out_height, &out_channels);
51 const int out_size = out_width * out_height;
52 const int out_stride = out_width;
54 float *output_ =
55 (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
56 ASSERT_NE(output_, nullptr);
57 float *output[CNN_MAX_CHANNELS] = { nullptr };
58 for (int channel = 0; channel < out_channels; ++channel) {
59 output[channel] = output_ + (channel * out_size);
61 const int num_outputs = 1;
62 const int output_chs[1] = { out_channels };
63 const int output_strides[1] = { out_stride };
64 CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
65 output };
67 RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
68 thread_data, &output_struct, &expected, tolerance);
70 aom_free(output_);
73 static void RunMultiOutCNNTest(const float **input, int image_width,
74 int image_height, int in_stride,
75 const CNN_CONFIG *cnn_config,
76 CNN_THREAD_DATA *thread_data,
77 CNN_MULTI_OUT *output, const float **expected,
78 double tolerance) {
79 const int num_outputs = output->num_outputs;
80 const int *output_chs = output->output_channels;
82 int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
83 int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
84 int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
85 ASSERT_NE(out_widths, nullptr);
86 ASSERT_NE(out_heights, nullptr);
87 ASSERT_NE(not_used, nullptr);
89 av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
90 out_heights, not_used);
91 ASSERT_TRUE(av1_cnn_predict(input, image_width, image_height, in_stride,
92 cnn_config, thread_data, output));
94 int channel_offset = 0;
95 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
96 const float *expected_out = expected[output_idx];
97 const int curr_output_chs = output_chs[output_idx];
98 const int out_size = out_widths[output_idx] * out_heights[output_idx];
100 double mse = 0;
101 int expected_ite = 0;
102 for (int channel = 0; channel < curr_output_chs; ++channel) {
103 const float *buf_out = output->output_buffer[channel_offset];
105 for (int i = 0; i < out_size; ++i) {
106 EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
107 PIXELWISE_FLOAT_TOL)
108 << " output " << output_idx << " channel " << channel << " pixel "
109 << expected_ite % out_size << ": " << expected_out[expected_ite]
110 << "/" << buf_out[i] << std::endl;
111 mse += SQR(expected_out[expected_ite] - buf_out[i]);
112 expected_ite++;
115 channel_offset++;
117 mse /= (out_size * curr_output_chs);
118 EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
121 aom_free(out_widths);
122 aom_free(out_heights);
123 aom_free(not_used);
126 static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
127 float *bias) {
128 size_t weight_offset = 0;
129 size_t bias_offset = 0;
130 for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
131 CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
132 layer_config->weights = weights + weight_offset;
133 layer_config->bias = bias + bias_offset;
134 weight_offset += layer_config->filter_width *
135 layer_config->filter_height * layer_config->in_channels *
136 layer_config->out_channels;
137 bias_offset += layer_config->out_channels;
139 ASSERT_NE(layer_config->weights, nullptr);
140 ASSERT_NE(layer_config->bias, nullptr);
145 } // namespace
// Runs a 16x16 input through a multi-layer configuration three times: with
// every layer using PADDING_SAME_ZERO, then PADDING_SAME_REPLICATE, then
// PADDING_VALID. Each run is compared with the matching expected_* table.
// All layers use a filter that is 4 wide and 5 high with NONE activation.
// NOTE(review): the leading 3 in cnn_config is presumably the layer count
// (the loops below iterate cnn_config.num_layers) -- confirm against
// CNN_CONFIG.
147 TEST_F(CNNTest, TestMultilayerConvolution) {
148 int image_height = 16;
149 int image_width = 16;
150 int filter_height = 5;
151 int filter_width = 4;
153 float input[] = {
154 -3, 1, -3, 2, -2, -2, 2, -2, 1, -2, -3, 1, 2, 2, 2, -2, 0, 1, -1,
155 -3, -1, -1, 1, 0, -3, 1, 0, -1, 1, 0, 0, -3, -3, -3, 0, 2, 1, -1,
156 2, 0, 1, -3, -1, 2, 2, 1, -2, 0, -1, 0, -2, -2, -1, 1, 0, 0, 0,
157 -2, -2, -2, 1, 1, -2, 1, 1, -2, -2, 1, -2, -1, -2, -3, 2, -3, -1, 1,
158 0, -2, -2, -2, 1, -2, -2, -1, -1, 2, 2, 2, -1, 1, -3, -3, 0, 2, 0,
159 2, 1, -3, -3, 1, 2, 2, 1, -2, -3, 0, -3, 0, -3, -2, 0, 1, 1, 0,
160 -3, 2, -1, 2, 1, 0, 1, -2, 1, -1, -1, 2, 0, -2, -3, 1, 1, -2, -1,
161 -3, -3, -1, 0, -3, -2, 0, 0, 1, 0, -3, -2, -1, 1, 0, 2, 1, 0, -3,
162 -2, -3, -3, -1, 0, -2, 2, -1, -3, 0, -1, -1, 2, 0, -3, -2, -1, 0, 0,
163 1, -2, 1, 2, 1, 2, 2, -3, 2, -1, 0, 0, -1, 0, 2, 2, -1, 2, -2,
164 1, 1, -3, -3, 1, -1, -1, -2, 2, -2, -2, 2, -1, -3, 2, -3, 1, -1, -1,
165 -3, 1, -1, 1, 0, -3, -3, 1, -3, -3, 0, 2, 2, -2, -1, 2, 0, 2, 1,
166 -1, -3, 0, 0, -1, -1, 1, 0, 2, 0, -3, 2, 1, 0, 1, -3, 2, -3, -3,
167 -1, -3, -3, 2, 0, 2, -2, 1, -1,
// Weights for all layers, packed back to back; AssignLayerWeightsBiases()
// hands each layer its slice.
170 float weights[] = {
171 -2, 2, -2, 2, -1, -3, 2, 2, 0, 0, -3, -1, -2, -3, 1, -1, 0, 0, 0,
172 2, -2, 2, -2, -3, 1, 1, 1, -3, -1, 0, 1, 2, -2, 0, -1, -3, -1, -2,
173 2, -3, -3, 1, -2, -3, 0, 2, 1, -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
174 -1, -3, -1, -2, -2, -3, 2, 0, -3, 0, -3, -3, 1, -3, -1, 0, -1, 1, 1,
175 -1, 1, -2, 0, 2, 0, -3, 1, -1, -1, 2, 0, 1, -3, -3, 1, 2, -3, -3,
176 1, -3, 2, 0, -3, 1, 2, 2, -2, -1, -2, 1, 1, 0, -2, -2, 1, 2, -1,
177 -3, 1, -2, 2, -3, -2, -3, 2, 1, 0, -2, 0, 1, -3, 2, -2, -2, 0, 2,
178 -3, 2, 0, 0, 1, -2, 1, 1, -2, -1, -2, 1, -2, 0, -2, -2, 0, -1, -1,
179 -3, -3, -3, 1, -3, -2, 2, -1, 2, 0, 2, -2, 2, -2, 1, -3, -3, -1, 0,
180 2, 2, 1, -1, -3, -1, -3, 2, 1, -2, 0, -3, -1, -3, -1, 2, 1, 0, 2,
181 -1, 1, 0, 1, 2, -1, -2, 2, 1, -3, -1, -3, 0, 1, -2, 0, -2, -3, 0,
182 -2, 2, 2, 0, 0, 2, -3, 2, -3, -2, 1, 2, -3, -3, -1, -3, 0, -3, -3,
183 -2, -2, -2, 0, 0, 1, 0, 0, -1, 0, 0, -3, 0, -3, -1, -2, 1, -2, -1,
184 2, -2, 0, 0, 1, 0, -2, -1, 0, -3, 1, 0, -1, -3, 1, -1, 1, -1, -3,
185 1, 0, 1, 1, -1, 2, 2, 0, 0, 1, -3, 2, -2, -2, -3, -2, -1, -2, 2,
186 0, 2, -2, -3, -1, -3, 2, 2, -1, 2, 2, -1, 0, -3, 1,
189 float bias[] = {
190 1, -1, 0, 1, 1, 1, -2,
// Expected output with PADDING_SAME_ZERO on every layer.
193 float expected_same[] = {
194 -1125, 2926, 6406, 631, -1244, 97, -1454, 2526, 1065, 3292, 3464,
195 2553, -330, 532, 1038, 1182, -402, 3758, 3392, 9854, 4365, 1408,
196 4736, 3134, 3838, 2409, 3221, 4350, 6750, 4045, 815, 1188, 2959,
197 9802, 9590, 4572, 5740, 4253, 1701, 7974, 7012, 6854, 7093, 3907,
198 4539, 3886, 4267, 3505, 465, 7824, 9219, 10026, 7968, 957, 2295,
199 5594, 10811, 9641, 5950, 10043, 8783, 3132, 1421, 1110, 4108, 13929,
200 10660, -84, -61, 3932, -180, 6811, 13393, 15147, 15640, 9337, 6961,
201 3808, 1604, 1398, 1047, 6739, 10144, 6517, 4698, 2678, 7389, 2595,
202 5248, 12075, 11272, 13951, 8820, 1090, 2199, 2206, 2788, 12116, 6683,
203 2612, -291, 3183, 9414, 12316, 14524, 12333, 13208, 7832, 4664, 4657,
204 3534, 1298, -666, 4250, 7707, 9103, 5760, 688, 9571, 15782, 14203,
205 14878, 17339, 14684, 8690, 5671, 875, 1429, 1531, 6173, 2984, 5558,
206 2996, 7928, 6733, 16117, 15262, 12757, 7980, 3923, 4795, 5973, 2051,
207 455, -1922, 1816, 5906, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
208 7451, 6666, 74, -1645, -35, -391, 3813, 7324, 892, 1656, 6095,
209 12193, 14648, 12156, 14663, 10251, 10325, 7821, 3925, 323, 697, 442,
210 1324, 4669, 7002, 5485, 5171, 5086, 10582, 11053, 9709, 11353, 8543,
211 5256, 2873, 235, -628, 1496, 1878, -867, 3420, 6865, 5937, 10182,
212 13277, 10069, 10789, 5998, 624, -2082, 4417, 1258, -1080, -819, -1430,
213 1033, 5220, 6335, 8471, 8980, 11908, 14430, 12584, 8404, 1576, -803,
214 985, 1481, 1367, -193, 873, 3684, 2288, 6676, 9477, 11155, 9602,
215 9707, 10507, 4739, 3174, -575, -178, 3002, 1710, 423, -477, 554,
216 3088, 2029, 5113, 5000, 3771, 6090, 5365, 1185, 2855, 399, -312,
217 -1577, 176, 955,
// Expected output with PADDING_SAME_REPLICATE on every layer.
220 float expected_replicate[] = {
221 13768, 13528, 12999, 6906, 4618, 4043, 2611, 9955, 6685, 4776, 2753,
222 1036, 3063, 4544, 5183, 7349, 12451, 12501, 9131, 12753, 8908, 4058,
223 6299, 7542, 7115, 3307, 3360, 3543, 9754, 7808, 5991, 9019, 14320,
224 14919, 12492, 6871, 7373, 3336, 2085, 10604, 9377, 6882, 5009, 3103,
225 6220, 6278, 7588, 10196, 11045, 11563, 11842, 11911, 8279, 2030, 1858,
226 6368, 12123, 9909, 6347, 10345, 9365, 4038, 1673, 3051, 16492, 16649,
227 12276, 408, -301, 4122, -654, 7864, 14038, 15279, 15315, 9744, 8243,
228 5298, 746, 380, 9824, 9124, 10895, 6640, 4712, 2669, 6980, 2759,
229 5385, 12345, 11336, 13129, 8600, 2370, 3682, 5219, 12407, 13123, 6784,
230 2612, -291, 3183, 9414, 12316, 14524, 12333, 13397, 7543, 3916, 4153,
231 4477, 4314, 7983, 8418, 9163, 9103, 5760, 688, 9571, 15782, 14203,
232 14878, 17718, 14570, 7940, 6642, 5094, 7133, 9964, 10219, 3224, 5558,
233 2996, 7928, 6733, 16117, 15262, 12757, 7958, 4401, 5187, 5476, 5529,
234 6055, 2206, 3909, 6015, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
235 6967, 6840, 481, -1600, 274, 1, 10373, 8514, 1123, 2117, 6758,
236 12736, 16223, 13585, 15988, 11771, 10600, 7918, 4156, 2840, 3111, 3287,
237 6359, 7652, 8813, 6530, 6967, 7789, 13671, 13990, 13247, 13241, 9836,
238 5251, 3024, 2313, 1834, 4187, 2637, -1312, 2139, 7378, 7665, 11933,
239 15591, 15314, 15678, 9531, 2820, -1516, 3400, 1314, 22, 363, -2896,
240 -898, 5906, 7308, 10650, 12975, 16978, 20370, 18817, 12381, 4118, -861,
241 -137, 236, 1802, 1632, -350, 2334, 3400, 8680, 14064, 18216, 18675,
242 21765, 22871, 11491, 4937, -1555, -11, 1669, 2392, 3265, -5254, -217,
243 5001, 8063, 13444, 18884, 19706, 22794, 21064, 9545, 6689, -7, 289,
244 -2021, 504, 2347,
// Expected output with PADDING_VALID on every layer.
247 float expected_valid[] = {
248 2612, -291, 3183, 9414, 12316, 14524, 12333, 9103, 5760, 688,
249 9571, 15782, 14203, 14878, 5558, 2996, 7928, 6733, 16117, 15262,
250 12757, 3321, 10908, 10910, 7377, 12204, 12809, 11195,
// The weights/bias slots are nullptr here and filled in below.
253 CNN_CONFIG cnn_config = { 3,
261 filter_width,
262 filter_height,
267 nullptr,
268 nullptr,
269 PADDING_SAME_ZERO,
270 NONE,
273 BRANCH_NO_COPY,
274 BRANCH_NOC,
281 filter_width,
282 filter_height,
287 nullptr,
288 nullptr,
289 PADDING_SAME_ZERO,
290 NONE,
293 BRANCH_NO_COPY,
294 BRANCH_NOC,
301 filter_width,
302 filter_height,
307 nullptr,
308 nullptr,
309 PADDING_SAME_ZERO,
310 NONE,
313 BRANCH_NO_COPY,
314 BRANCH_NOC,
319 } };
321 // Weights and biases need to be specified separately because
322 // of the offset.
323 AssignLayerWeightsBiases(&cnn_config, weights, bias);
325 CNN_THREAD_DATA thread_data = { 1, nullptr };
// MSE_INT_TOL is 0, so the mean squared error must be exactly zero.
327 RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
328 image_width, &thread_data, MSE_INT_TOL);
// Switch every layer to replicate padding and re-run.
330 for (int i = 0; i < cnn_config.num_layers; ++i) {
331 cnn_config.layer_config[i].pad = PADDING_SAME_REPLICATE;
334 RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
335 image_width, &thread_data, MSE_INT_TOL);
// Switch every layer to valid padding and re-run.
337 for (int i = 0; i < cnn_config.num_layers; ++i) {
338 cnn_config.layer_config[i].pad = PADDING_VALID;
341 RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
342 image_width, &thread_data, MSE_INT_TOL);
// Runs a single layer with RELU activation over an 8x8 input, using a filter
// that is 4 wide and 5 high (20 weights, one bias). The same layer is checked
// with PADDING_SAME_ZERO, PADDING_SAME_REPLICATE and PADDING_VALID.
345 TEST_F(CNNTest, TestRELUSingleLayer) {
346 int image_width = 8;
347 int image_height = 8;
348 int filter_height = 5;
349 int filter_width = 4;
350 float input[] = {
351 0, -2, -3, 1, -1, 2, -2, 1, -3, -1, 0, 1, -2, -3, -2, -2,
352 1, -3, 2, -3, -1, -1, 2, 0, -2, -3, 0, -2, -3, 1, -1, -1,
353 2, -2, 0, -2, -3, -3, 1, 1, -1, 1, 0, 1, -3, 0, 2, 2,
354 0, -3, 1, -3, 2, -2, 1, -1, -1, -2, -3, -2, -1, -3, -2, -1,
// Expected output for PADDING_SAME_ZERO.
356 float expected_same[] = {
357 9, 0, 1, 1, 0, 3, 0, 19, 0, 12, 10, 0, 0, 0, 5, 0,
358 0, 18, 21, 7, 19, 4, 3, 0, 0, 9, 16, 0, 11, 16, 0, 11,
359 12, 2, 0, 11, 0, 16, 6, 0, 8, 22, 13, 10, 12, 0, 0, 0,
360 0, 1, 2, 12, 29, 6, 10, 0, 13, 0, 0, 5, 8, 10, 0, 0,
// Expected output for PADDING_SAME_REPLICATE.
362 float expected_replicate[] = {
363 18, 17, 12, 2, 0, 0, 5, 11, 0, 17, 22, 6, 0, 0, 17, 0,
364 0, 18, 21, 7, 19, 4, 3, 5, 3, 9, 16, 0, 11, 16, 0, 3,
365 3, 2, 0, 11, 0, 16, 6, 0, 17, 22, 13, 10, 12, 0, 0, 0,
366 0, 4, 1, 10, 30, 7, 10, 0, 23, 8, 0, 13, 15, 19, 8, 10,
// Expected output for PADDING_VALID.
368 float expected_valid[] = {
369 18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
371 float weights[] = {
372 -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
374 float bias[] = { -3 };
// Single layer, so weights and bias are wired directly into the initializer.
376 CNN_CONFIG cnn_config = { 1,
383 filter_width,
384 filter_height,
389 weights,
390 bias,
391 PADDING_SAME_ZERO,
392 RELU,
395 BRANCH_NO_COPY,
396 BRANCH_NOC,
400 } } };
402 CNN_THREAD_DATA thread_data = { 1, nullptr };
// MSE_INT_TOL is 0, so the mean squared error must be exactly zero.
404 RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
405 image_width, &thread_data, MSE_INT_TOL);
407 cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
409 RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
410 image_width, &thread_data, MSE_INT_TOL);
412 cnn_config.layer_config[0].pad = PADDING_VALID;
414 RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
415 image_width, &thread_data, MSE_INT_TOL);
// Runs one single-layer configuration (PADDING_SAME_ZERO, NONE activation)
// over a non-square input that is 17 wide and 24 high. The first run uses the
// skip values from the initializer; the later runs overwrite layer 0's
// skip_width / skip_height and compare against a fresh expected table each
// time.
418 TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
419 float weights[] = {
420 1, -5, -3, -4, -1, 1, 2, -3, 2, 2, -1, 1, -5, 1, 1,
421 -3, -5, 3, 1, 4, -2, -5, -2, -3, -5, 0, -1, -5, 2, -2,
422 -2, 1, -2, -4, 1, 3, -2, 2, 0, -3, 2, -3, -2, -3,
424 float bias[] = { 2 };
426 CNN_CONFIG cnn_config = { 1,
440 weights,
441 bias,
442 PADDING_SAME_ZERO,
443 NONE,
446 BRANCH_NO_COPY,
447 BRANCH_NOC,
452 } };
454 int image_height = 24;
455 int image_width = 17;
456 float input[] = {
457 -1, -3, 4, 4, -5, 4, 3, -5, -1, -3, 4, -4, 2, -3, 3, -5, 2, -1, -5,
458 1, -1, 3, 1, -3, -3, 4, 0, 2, -3, -5, -5, -4, 0, -5, -2, -3, -1, -2,
459 2, -5, 4, 4, 0, -4, -3, 1, -3, -5, -4, -4, 1, -2, -3, 3, -3, -3, -1,
460 -5, -5, -2, 3, 1, -1, -5, -5, 1, -4, -2, -1, -2, -4, -4, 2, -2, 2, 1,
461 -2, -4, -1, 1, -2, -5, 3, -2, -1, -1, -5, -3, 1, -2, -2, -3, -1, -2, -4,
462 -2, 1, -4, -1, 4, 3, -4, 0, 4, 2, 2, 4, -3, -5, 2, 2, 1, -1, -4,
463 -2, 1, 3, 2, 0, 4, -1, -3, 2, 1, -4, 2, 2, -4, -2, 0, -2, -1, 4,
464 4, 2, 3, -4, 2, -4, -5, 4, -1, -3, -1, 0, -4, 1, 3, -1, -3, -5, 3,
465 -2, -4, 1, 2, -2, -3, -3, -5, 1, -3, -1, 0, -1, 3, -4, -1, -5, -5, 1,
466 0, 0, -2, -2, 2, -2, 0, 0, 2, 0, -3, 0, -1, -4, -4, -1, 3, -4, -4,
467 -1, 0, -5, -3, -2, 4, -3, -4, -4, 0, -5, 1, -2, -3, -3, -4, 4, 3, 4,
468 3, 3, -1, 3, 1, -3, -2, 3, 3, 0, 2, -4, -3, 2, 2, 0, -2, 4, -2,
469 2, -2, -1, -4, -2, 2, -4, 3, -1, 4, 1, 1, 4, -1, -4, -4, 1, 1, -2,
470 4, -1, 3, 2, -3, 4, 3, 1, 4, 0, -4, 2, 0, 2, 4, -2, -2, 4, 2,
471 -1, -2, 1, -3, 2, 3, -5, -3, 4, 4, 2, -5, -4, -5, -2, -4, 2, 0, 2,
472 -5, 4, -4, -2, -5, 2, 1, 0, 4, 1, -2, -3, -4, -3, -4, 3, 3, 2, 0,
473 -3, 1, -5, 4, 0, 4, -1, 3, -5, -5, -2, -1, -1, 4, 3, 3, 4, 3, -4,
474 4, -3, -3, -1, -4, -1, -4, -1, -2, 4, -2, -4, 4, 4, -3, -4, -1, 1, 2,
475 -1, -2, -2, 3, 2, 2, -3, 0, -1, 0, 3, 2, -5, 0, -4, 0, 0, 2, -4,
476 -1, -1, 0, -2, 0, 1, 0, 0, 4, -5, -1, -5, 2, -1, 0, 2, -1, 1, 3,
477 -3, -5, -2, -3, 4, -2, -2, -1, -3, -4, -1, -2, -4, 1, 4, -3, -2, -1, 3,
478 -3, -2, 3, 2, 1, -4, -3, -5, 1,
// Run 1: skip values as given in the cnn_config initializer.
480 float expected_1[] = {
481 41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
484 CNN_THREAD_DATA thread_data = { 1, nullptr };
486 RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
487 image_width, &thread_data, MSE_INT_TOL);
// Run 2: skip_width 6, skip_height 7.
489 cnn_config.layer_config[0].skip_width = 6;
490 cnn_config.layer_config[0].skip_height = 7;
492 float expected_2[] = {
493 21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
495 RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
496 image_width, &thread_data, MSE_INT_TOL);
// Run 3: skip_width 3, skip_height 10.
498 cnn_config.layer_config[0].skip_width = 3;
499 cnn_config.layer_config[0].skip_height = 10;
501 float expected_3[] = {
502 -26, -21, -35, 69, 49, 4, -51, -43, -56,
503 -41, 15, -44, 40, -62, 63, 38, 27, 47,
505 RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
506 image_width, &thread_data, MSE_INT_TOL);
// Run 4: skip_width 10, skip_height 3.
508 cnn_config.layer_config[0].skip_width = 10;
509 cnn_config.layer_config[0].skip_height = 3;
511 float expected_4[] = {
512 21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
515 RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
516 image_width, &thread_data, MSE_INT_TOL);
// Runs a single layer (PADDING_SAME_ZERO, NONE activation, 9 weights) over an
// 8x8 input and expects 9 output values.
// NOTE(review): `stride` (3) appears twice in the layer initializer,
// presumably as skip_width and skip_height, and the test name suggests the
// layer's max-pooling option is enabled -- confirm both against the
// CNN_LAYER_CONFIG field order.
519 TEST_F(CNNTest, TestMaxPool) {
520 int image_width = 8;
521 int image_height = 8;
522 int stride = 3;
523 float input[] = {
524 1, -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8, 5, -1, -1, 9,
525 -3, 0, -2, 0, 6, 3, -4, 8, 7, 8, 7, -1, 4, -1, 0, 2,
526 -5, -2, 8, 5, 5, 4, 2, 7, 4, 6, 2, 8, 8, -4, -3, -4,
527 -3, -1, 2, 3, 3, 6, -5, 8, 9, 5, 0, -2, -1, 6, 5, 7,
530 float expected[] = {
531 49, 58, 70, 68, 68, 70, 48, 57, 88,
534 float weights[] = {
535 3, 1, 3, 4, -1, 5, -2, 1, -4,
538 float bias[] = {
542 CNN_CONFIG cnn_config = { 1,
552 stride,
553 stride,
555 weights,
556 bias,
557 PADDING_SAME_ZERO,
558 NONE,
561 BRANCH_NO_COPY,
562 BRANCH_NOC,
566 } } };
568 CNN_THREAD_DATA thread_data = { 1, nullptr };
// MSE_INT_TOL is 0, so the mean squared error must be exactly zero.
570 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
571 image_width, &thread_data, MSE_INT_TOL);
// Runs a single layer with NONE activation over an input that is 4 wide and
// 7 high, and re-runs it while changing layer 0's padding, skip_width /
// skip_height, filter size, weights and bias. Every run is compared against
// its own expected_* table with an MSE tolerance of zero (MSE_INT_TOL).
// NOTE(review): per the test name the layer is presumably configured for
// deconvolution in the initializer -- confirm against CNN_LAYER_CONFIG.
574 TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
575 int image_width = 4;
576 int image_height = 7;
577 float input[] = {
578 9, 6, 181, 9, 218, 30, 80, 108, 68, 216, 70, 128, 179, 228,
579 33, 212, 34, 14, 48, 27, 230, 23, 202, 113, 80, 56, 122, 112,
// Expected tables for the first configuration (weights_1 / bias_1 with the
// skip values from the initializer).
582 float expected_1_same[] = {
583 15, -30, 36, -525, 377, -193, 558, 531, 6, -24, -15, 124,
584 166, -561, -356, -754, -3, -3, -3, -3, -3, -3, -3, -3,
585 433, -311, 711, 381, 247, -317, 453, 129, 215, -627, -409, -885,
586 17, -255, -55, -647, -3, -3, -3, -3, -3, -3, -3, -3,
587 133, -719, 633, -225, 785, 191, 463, 79, 65, 9, 77, -853,
588 -365, -949, -15, -667, -3, -3, -3, -3, -3, -3, -3, -3,
589 355, -866, 990, 207, 747, 12, 520, -116, 176, -312, -133, -1370,
590 -426, -802, 143, -771, -3, -3, -3, -3, -3, -3, -3, -3,
591 65, -79, 127, -59, 135, -90, 195, 114, 31, -91, -57, -133,
592 17, -176, -72, -276, -3, -3, -3, -3, -3, -3, -3, -3,
593 457, -302, 733, 58, 470, -475, 829, 490, 227, -670, -440, -790,
594 153, -588, -294, -1150, -3, -3, -3, -3, -3, -3, -3, -3,
595 157, -251, 349, -185, 409, -293, 587, 251, 77, -187, -107, -369,
596 7, -481, -135, -827, -3, -3, -3, -3, -3, -3, -3, -3,
598 float expected_1_valid[] = {
599 -30, 15, -30, 36, -525, 377, -193, 558, 531, 24, 24, 6,
600 6, -24, -15, 124, 166, -561, -356, -754, -21, -39, -3, -3,
601 -3, -3, -3, -3, -3, -3, -3, -3, -3, -657, 433, -311,
602 711, 381, 247, -317, 453, 129, 321, 321, 215, 215, -627, -409,
603 -885, 17, -255, -55, -647, -219, -435, -3, -3, -3, -3, -3,
604 -3, -3, -3, -3, -3, -3, -207, 133, -719, 633, -225, 785,
605 191, 463, 79, 381, 381, 65, 65, 9, 77, -853, -365, -949,
606 -15, -667, -259, -515, -3, -3, -3, -3, -3, -3, -3, -3,
607 -3, -3, -3, -540, 355, -866, 990, 207, 747, 12, 520, -116,
608 633, 633, 176, 176, -312, -133, -1370, -426, -802, 143, -771, -427,
609 -851, -3, -3, -3, -3, -3, -3, -3, -3, -3, -3, -3,
610 -105, 65, -79, 127, -59, 135, -90, 195, 114, 78, 78, 31,
611 31, -91, -57, -133, 17, -176, -72, -276, -57, -111, -3, -3,
612 -3, -3, -3, -3, -3, -3, -3, -3, -3, -693, 457, -302,
613 733, 58, 470, -475, 829, 490, 336, 336, 227, 227, -670, -440,
614 -790, 153, -588, -294, -1150, -229, -455, -3, -3, -3, -3, -3,
615 -3, -3, -3, -3, -3, -3, -243, 157, -251, 349, -185, 409,
616 -293, 587, 251, 333, 333, 77, 77, -187, -107, -369, 7, -481,
617 -135, -827, -227, -451,
619 float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
620 float bias_1[] = { -3 };
622 CNN_CONFIG cnn_config = { 1,
635 weights_1,
636 bias_1,
637 PADDING_SAME_ZERO,
638 NONE,
641 BRANCH_NO_COPY,
642 BRANCH_NOC,
646 } } };
648 CNN_THREAD_DATA thread_data = { 1, nullptr };
650 RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
651 image_width, &thread_data, MSE_INT_TOL);
653 // Change padding to valid
654 cnn_config.layer_config[0].pad = PADDING_VALID;
656 RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
657 image_width, &thread_data, MSE_INT_TOL);
// Expected tables for the same weights with skip_width 3, skip_height 2.
659 float expected_12_same[] = {
660 15, -12, 6, 36, -9, -528, 377, -184, 513, 558, -12, 24,
661 6, -30, -15, -33, -21, 166, 154, -546, -356, -718, -30, -21,
662 433, -221, 561, 711, -33, -153, 247, -83, -87, 453, -111, 321,
663 215, -657, -409, -845, -93, 17, -43, -243, -55, -215, -327, -219,
664 133, -71, -447, 633, -219, 435, 785, -73, -177, 463, -131, 381,
665 65, -207, 77, -59, -651, -365, -797, -213, -15, -155, -387, -259,
666 355, -182, -150, 990, -231, 582, 747, -36, -540, 520, -215, 633,
667 176, -540, -133, -491, -687, -426, -882, -102, 143, 77, -639, -427,
668 65, -37, 57, 127, -17, -105, 135, -51, 60, 195, -30, 78,
669 31, -105, -57, -125, -45, 17, -11, -147, -72, -168, -84, -57,
670 457, -233, 618, 733, -26, -540, 470, -205, 264, 829, -116, 336,
671 227, -693, -440, -900, -72, 153, 107, -609, -294, -698, -342, -229,
672 157, -83, 69, 349, -59, -201, 409, -125, 27, 587, -115, 333,
673 77, -243, -107, -267, -171, 7, -105, -369, -135, -379, -339, -227,
675 float expected_12_valid[] = {
676 -30, 15, -12, 6, 36, -9, -528, 377, -184, 513, 558, -12,
677 24, 24, 6, 6, -30, -15, -33, -21, 166, 154, -546, -356,
678 -718, -30, -21, -39, -657, 433, -221, 561, 711, -33, -153, 247,
679 -83, -87, 453, -111, 321, 321, 215, 215, -657, -409, -845, -93,
680 17, -43, -243, -55, -215, -327, -219, -435, -207, 133, -71, -447,
681 633, -219, 435, 785, -73, -177, 463, -131, 381, 381, 65, 65,
682 -207, 77, -59, -651, -365, -797, -213, -15, -155, -387, -259, -515,
683 -540, 355, -182, -150, 990, -231, 582, 747, -36, -540, 520, -215,
684 633, 633, 176, 176, -540, -133, -491, -687, -426, -882, -102, 143,
685 77, -639, -427, -851, -105, 65, -37, 57, 127, -17, -105, 135,
686 -51, 60, 195, -30, 78, 78, 31, 31, -105, -57, -125, -45,
687 17, -11, -147, -72, -168, -84, -57, -111, -693, 457, -233, 618,
688 733, -26, -540, 470, -205, 264, 829, -116, 336, 336, 227, 227,
689 -693, -440, -900, -72, 153, 107, -609, -294, -698, -342, -229, -455,
690 -243, 157, -83, 69, 349, -59, -201, 409, -125, 27, 587, -115,
691 333, 333, 77, 77, -243, -107, -267, -171, 7, -105, -369, -135,
692 -379, -339, -227, -451,
695 // Change skip_width, skip_height to {3, 2}
696 cnn_config.layer_config[0].skip_width = 3;
697 cnn_config.layer_config[0].skip_height = 2;
698 // Set padding to same
699 cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
701 RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
702 image_width, &thread_data, MSE_INT_TOL);
704 // Change padding to valid
705 cnn_config.layer_config[0].pad = PADDING_VALID;
706 RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
707 image_width, &thread_data, MSE_INT_TOL);
// Switch to a filter that is 4 wide and 3 high, with its own 12 weights and
// bias, and skip_width 5, skip_height 2.
709 cnn_config.layer_config[0].filter_width = 4;
710 cnn_config.layer_config[0].filter_height = 3;
711 float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
712 float bias_2[] = { -4 };
713 cnn_config.layer_config[0].weights = weights_2;
714 cnn_config.layer_config[0].bias = bias_2;
716 cnn_config.layer_config[0].skip_width = 5;
717 cnn_config.layer_config[0].skip_height = 2;
718 float expected_2_same[] = {
719 -13, -31, -13, -31, -4, -10, -22, -10, -22, -4, -185, -547,
720 -185, -547, -4, -13, -31, -13, -31, -4, -4, 14, -22, 32,
721 -4, -4, 8, -16, 20, -4, -4, 358, -366, 720, -4, -4,
722 14, -22, 32, -4, -195, -658, -213, -622, -4, -16, -94, -28,
723 -70, -4, 459, -244, 97, 480, -4, -85, -328, -103, -292, -4,
724 -4, 432, -440, 868, -4, -4, 56, -64, 116, -4, -4, 156,
725 -164, 316, -4, -4, 212, -220, 428, -4, 582, -208, 146, 664,
726 -4, -130, -652, -190, -532, -4, 166, -214, 6, 106, -4, 192,
727 -388, -24, 44, -4, -4, 132, -140, 268, -4, -4, 428, -436,
728 860, -4, -4, 136, -144, 276, -4, -4, 252, -260, 508, -4,
729 21, -541, -115, -269, -4, 416, -688, -16, 176, -4, 173, -103,
730 33, 177, -4, 168, -640, -88, -128, -4, -4, 354, -362, 712,
731 -4, -4, 452, -460, 908, -4, -4, 62, -70, 128, -4, -4,
732 420, -428, 844, -4, 499, -106, 141, 610, -4, 666, -46, 210,
733 866, -4, 47, -148, -19, -16, -4, 605, -85, 181, 763, -4,
734 -4, 64, -72, 132, -4, -4, 24, -32, 52, -4, -4, 92,
735 -100, 188, -4, -4, 50, -58, 104, -4, -132, -694, -200, -558,
736 -4, 15, -73, -13, -17, -4, -62, -610, -158, -418, -4, -36,
737 -343, -90, -235, -4, -4, 456, -464, 916, -4, -4, 42, -50,
738 88, -4, -4, 400, -408, 804, -4, -4, 222, -230, 448, -4,
739 606, -244, 146, 676, -4, 9, -172, -37, -80, -4, 480, -370,
740 76, 438, -4, 223, -340, -3, 112, -4, -4, 156, -164, 316,
741 -4, -4, 108, -116, 220, -4, -4, 240, -248, 484, -4, -4,
742 220, -228, 444, -4,
744 float expected_2_valid[] = {
745 -13, -31, -13, -31, -4, -10, -22, -10, -22, -4, -185, -547,
746 -185, -547, -4, -13, -31, -13, -31, -4, 14, -22, 32, -4,
747 -4, 8, -16, 20, -4, -4, 358, -366, 720, -4, -4, 14,
748 -22, 32, -195, -658, -213, -622, -4, -16, -94, -28, -70, -4,
749 459, -244, 97, 480, -4, -85, -328, -103, -292, -4, 432, -440,
750 868, -4, -4, 56, -64, 116, -4, -4, 156, -164, 316, -4,
751 -4, 212, -220, 428, 582, -208, 146, 664, -4, -130, -652, -190,
752 -532, -4, 166, -214, 6, 106, -4, 192, -388, -24, 44, -4,
753 132, -140, 268, -4, -4, 428, -436, 860, -4, -4, 136, -144,
754 276, -4, -4, 252, -260, 508, 21, -541, -115, -269, -4, 416,
755 -688, -16, 176, -4, 173, -103, 33, 177, -4, 168, -640, -88,
756 -128, -4, 354, -362, 712, -4, -4, 452, -460, 908, -4, -4,
757 62, -70, 128, -4, -4, 420, -428, 844, 499, -106, 141, 610,
758 -4, 666, -46, 210, 866, -4, 47, -148, -19, -16, -4, 605,
759 -85, 181, 763, -4, 64, -72, 132, -4, -4, 24, -32, 52,
760 -4, -4, 92, -100, 188, -4, -4, 50, -58, 104, -132, -694,
761 -200, -558, -4, 15, -73, -13, -17, -4, -62, -610, -158, -418,
762 -4, -36, -343, -90, -235, -4, 456, -464, 916, -4, -4, 42,
763 -50, 88, -4, -4, 400, -408, 804, -4, -4, 222, -230, 448,
764 606, -244, 146, 676, -4, 9, -172, -37, -80, -4, 480, -370,
765 76, 438, -4, 223, -340, -3, 112, -4, 156, -164, 316, -4,
766 -4, 108, -116, 220, -4, -4, 240, -248, 484, -4, -4, 220,
767 -228, 444, 236, -4, 76, 316, -4, 164, -4, 52, 220, -4,
768 362, -4, 118, 484, -4, 332, -4, 108, 444,
770 // Set padding to same
771 cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
773 RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
774 image_width, &thread_data, MSE_INT_TOL);
776 cnn_config.layer_config[0].pad = PADDING_VALID;
778 RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
779 image_width, &thread_data, MSE_INT_TOL);
// Same weights_2 / bias_2 filter, now with skip_width 2, skip_height 5.
781 cnn_config.layer_config[0].skip_width = 2;
782 cnn_config.layer_config[0].skip_height = 5;
783 float expected_21_same[] = {
784 -31, -19, -49, -191, -565, -194, -574, -13, 14, -22, 44, -16,
785 382, -366, 738, -22, -4, 23, 32, 545, 20, 204, 720, 5,
786 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
787 -4, -4, -4, -4, -658, -252, -748, -114, -334, -192, -568, -112,
788 432, -440, 928, -64, 276, -164, 532, -220, -4, 304, 868, 266,
789 116, 400, 316, 104, -4, -4, -4, -4, -4, -4, -4, -4,
790 -4, -4, -4, -4, -4, -4, -4, -4, -208, -288, -856, -290,
791 -862, -202, -598, -132, 132, -140, 700, -436, 1000, -144, 532, -260,
792 -4, 712, 268, 422, 860, 450, 276, 124, -4, -4, -4, -4,
793 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
794 -541, -411, -1225, -265, -787, -249, -739, -216, 354, -362, 1168, -460,
795 974, -70, 552, -428, -4, 859, 712, 323, 908, 665, 128, 208,
796 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
797 -4, -4, -4, -4, -106, -52, -148, -66, -190, -79, -229, -31,
798 64, -72, 160, -32, 148, -100, 242, -58, -4, 72, 132, 154,
799 52, 125, 188, 23, -4, -4, -4, -4, -4, -4, -4, -4,
800 -4, -4, -4, -4, -4, -4, -4, -4, -694, -257, -763, -229,
801 -679, -319, -949, -117, 456, -464, 962, -50, 492, -408, 1030, -230,
802 -4, 295, 916, 625, 88, 537, 804, 109, -4, -4, -4, -4,
803 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
804 -244, -140, -412, -182, -538, -238, -706, -116, 156, -164, 428, -116,
805 464, -248, 708, -228, -4, 244, 316, 418, 220, 454, 484, 108,
806 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
807 -4, -4, -4, -4,
809 float expected_21_valid[] = {
810 -13, -31, -19, -49, -191, -565, -194, -574, -13, -31, -4, 14,
811 -22, 44, -16, 382, -366, 738, -22, 32, 23, -4, 23, 32,
812 545, 20, 204, 720, 5, 32, -4, -4, -4, -4, -4, -4,
813 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
814 -4, -4, -222, -658, -252, -748, -114, -334, -192, -568, -112, -328,
815 -4, 432, -440, 928, -64, 276, -164, 532, -220, 428, 650, -4,
816 304, 868, 266, 116, 400, 316, 104, 428, -4, -4, -4, -4,
817 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
818 -4, -4, -4, -4, -72, -208, -288, -856, -290, -862, -202, -598,
819 -132, -388, -4, 132, -140, 700, -436, 1000, -144, 532, -260, 508,
820 200, -4, 712, 268, 422, 860, 450, 276, 124, 508, -4, -4,
821 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
822 -4, -4, -4, -4, -4, -4, -183, -541, -411, -1225, -265, -787,
823 -249, -739, -216, -640, -4, 354, -362, 1168, -460, 974, -70, 552,
824 -428, 844, 533, -4, 859, 712, 323, 908, 665, 128, 208, 844,
825 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
826 -4, -4, -4, -4, -4, -4, -4, -4, -38, -106, -52, -148,
827 -66, -190, -79, -229, -31, -85, -4, 64, -72, 160, -32, 148,
828 -100, 242, -58, 104, 98, -4, 72, 132, 154, 52, 125, 188,
829 23, 104, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
830 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -234, -694,
831 -257, -763, -229, -679, -319, -949, -117, -343, -4, 456, -464, 962,
832 -50, 492, -408, 1030, -230, 448, 686, -4, 295, 916, 625, 88,
833 537, 804, 109, 448, -4, -4, -4, -4, -4, -4, -4, -4,
834 -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4,
835 -84, -244, -140, -412, -182, -538, -238, -706, -116, -340, -4, 156,
836 -164, 428, -116, 464, -248, 708, -228, 444, 236, -4, 244, 316,
837 418, 220, 454, 484, 108, 444,
840 cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
842 RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
843 image_width, &thread_data, MSE_INT_TOL);
845 cnn_config.layer_config[0].pad = PADDING_VALID;
847 RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
848 image_width, &thread_data, MSE_INT_TOL);
// Runs a CNN_CONFIG with a single visible layer entry (PADDING_SAME_ZERO, NONE
// in the activation slot) on two inputs: first with image_width = 11 and
// image_height = 10, then with image_width = 10 and image_height = 11 after
// layer 0 has been re-pointed at a second weight/bias set. Both results are
// compared with hard-coded expected arrays using MSE_INT_TOL.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
851 TEST_F(CNNTest, TestLargeKernelsAndStrides) {
852 float input_10x11[] = {
853 4, 4, 2, 4, 2, -5, -2, 3, -1, 0, 0, 1, 2, 0, -5, -2, -5, 1, -3,
854 -1, 4, -3, 2, -2, 1, 0, 1, -3, -3, -4, -2, -2, 1, -4, -1, 4, 1, -4,
855 -4, -4, 3, 2, -5, 3, -5, 1, 2, -4, 1, -1, 3, 4, -2, 3, -3, 3, 0,
856 2, -4, -5, -5, -2, -1, -2, 1, 1, 1, -2, 4, -5, 4, -1, -1, 2, 3, -4,
857 2, 2, 3, 0, 0, 1, 0, 3, 2, 3, 1, -2, 3, -4, 3, 2, 4, -2, 0,
858 4, -4, 1, -3, -3, -3, -5, 1, -3, -5, 0, 4, -1, -3, 2,
861 float weights_10x11[] = {
862 -3, 4, -4, -3, -5, 1, -2, 3, 1, -4, -4, 0, -1, 0, 3, 1, -3, -2, 0,
863 -1, 1, 3, -4, -4, -3, -3, -2, 4, 3, -5, 4, 2, -3, 4, -2, -1, 2, -1,
864 -5, 0, -3, 0, 3, -5, -5, 3, -4, -1, -5, 3, 4, 0, 4, -5, 2, -1, 2,
865 -1, -1, -1, -5, 0, -4, 3, -1, 1, 1, -1, 3, 2, -5, -4, 0, -4, 4, -5,
866 -3, 4, -5, 2, -5, -4, -4, -1, 3, 3, 0, 2, -4, 1, -2, 1, 1, 0, 3,
867 -2, 0, 1, 2, 4, -3, -1, -5, -5, 2, -4, 1, 1, 2, -4, -2, -2, 2, 1,
868 3, 4, -5, 1, -1, -3, -3, -1, -2, -5, 1, -1, 0, 1, 4, 4, 0, 0, 4,
869 -3, -1, -5, -3, 0, 1, 1, 1, -5, 3, 4, 3, -5, 3, -2, -2, 0, -4, 0,
870 0, -2, 1, -4, -1, 0, -5, -2, -2, -5, -3, -3, 1, 1, -3, 2, 4, 2, 4,
871 -4, -3, 3, 1, 1, 3, -4, 4, -2, -3, -3, -3, -3, -4, -2, 3, -5, 2, 4,
872 -1, -4, -4, 4, -2, -1, 3, -3, -4, -4, -2, 4, 1, 0, 2, -1, 4, -3, 1,
873 4, -3, 4, 4, 0, -4, 3, -2, -3, 2, 3, -1, -3, 2, 1, 4, -2, -3, 1,
874 4, -2, 2, -2, -5, -2, 1, 4, -1, -4, 4, -5, 2, -5, -4, -1, -2, 3, 1,
875 2, 1, -5, 1, -5, -4, -1, -2, 2, -2, -4, -3, -2, -2, 4, -1, 2, 2, -4,
876 2, -2, 4, -4, -2, -2, 1, -1, 1, 1, 1, -4, -5, -2, 3, -4, -1, 3, -2,
877 3, 2, -5, -4, 0, 3, -2, -4, -5, 3, -2, -4, 2, -2, 1, -4, 0, 2, -5,
878 1, -4, -1, -1, 4, -5, -4, 0, -5, -4, -3, -5, -4, 0, 2, 0, -4, 2, -2,
879 1, 1, -3, 2, 0, -4, 0, -4, 1, 0, -5, -1, -1, -1, -5, 4, 2, 2, -4,
880 3, -2, -2, 2, -3, -2, -1, 2, -4, -5, 2, -2, -4, -5, -5, -1, 2, -1, 0,
881 -5, -2, -2, -5, 0, 1, -1, -5, 0, 3, 2, 3, 0, -3, -2, 0, -5, -1, -2,
882 2, -4, -1, 2, 2, -5, 2, -4, 0, 3, -3, 1, 0, 0, 1, -5, -3, 1, -1,
883 0, -4, -3, 2, -4, -4, 4, -1, 0, 1, 2, -4, -5, 4, -2, 1, -4, -4, -3,
884 -1, -1, 1, -1, -4, -1, -4, -3, 2, -1, -2, -4, 1, 1, 0, -2, 0, -4, 3,
885 -3, 0, -4, -1, -4, 2, -1, -2, -5, -1, -2, -3, 3, -1, 0, -3, 0, 1, -5,
886 1, -5, 0, 1,
889 float bias_10x11[] = { 3 };
891 float expected_10x11[] = {
892 118,
895 CNN_CONFIG cnn_config = { 1,
908 weights_10x11,
909 bias_10x11,
910 PADDING_SAME_ZERO,
911 NONE,
914 BRANCH_NO_COPY,
915 BRANCH_NOC,
919 } } };
// First case: image_width = 11, image_height = 10.
921 int image_height = 10;
922 int image_width = 11;
924 CNN_THREAD_DATA thread_data = { 1, nullptr };
926 RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
927 &cnn_config, image_width, &thread_data, MSE_INT_TOL);
929 float input_11x10[] = {
930 -2, -2, 3, -5, -1, -3, 1, 3, 2, 1, 1, -5, 4, 1, 3, -5, 3, -3, -5,
931 0, -1, -3, -3, 1, 1, -5, -1, -5, -5, -3, 0, 1, -3, -1, -3, -3, 0, 3,
932 4, -4, -1, 3, -3, -1, -3, 1, -3, -2, -1, -4, -3, 2, -4, 1, -4, -1, -3,
933 -5, -1, 2, 3, 0, 2, 2, -5, 4, 1, 2, -1, -4, 4, -4, -4, 0, -1, 1,
934 -1, 1, -3, -3, -2, 1, 2, 4, 4, 4, -3, -3, 0, 1, 0, 1, 4, 1, 3,
935 4, -3, -2, -4, 4, 2, 0, 3, 4, -1, 2, -2, 1, -3, -2,
938 float weights_11x10[] = {
939 4, -1, 1, -1, 2, 4, 3, 3, -4, 3, -5, 1, -1, -1, -2, -2, 0, 2, -3,
940 -2, 3, -5, -1, 0, -1, -2, -2, -1, 2, 4, 3, 1, 0, 0, -3, 3, -4, -1,
941 -5, 4, -2, -2, 1, 2, -1, -3, 1, 2, -5, 1, -3, 3, 3, 0, -4, -4, -5,
942 -3, -4, -4, 4, -2, 4, 4, -2, 2, -5, -1, -2, -5, -1, 4, -3, 3, -2, 0,
943 -4, -3, 0, -1, -2, 4, 2, 0, -2, -5, -4, 1, 4, -4, -2, 2, -2, 1, 1,
944 -4, 1, -4, -4, -2, 4, 2, -1, -5, -5, 1, -3, -3, 3, -3, -5, -3, 4, -1,
945 -1, -3, 0, -4, 3, -1, 0, -2, 0, -5, -2, -5, 2, 0, -5, 2, 3, -2, 2,
946 4, -1, 1, -3, 2, 3, 2, 0, -5, -4, -5, 2, 1, 1, -1, -2, 3, 4, 2,
947 -2, 4, -2, 3, 1, -4, -3, -1, 4, 4, -3, -5, -2, 2, 0, 3, -2, 3, -1,
948 -4, 0, -2, 0, 3, 4, -2, -3, -2, 0, 3, 4, 2, -4, 0, 1, 2, 2, -1,
949 -1, 4, 1, 4, -2, -1, -1, -5, 1, -3, 3, 3, -1, -4, 3, -5, 0, 0, -1,
950 -4, -1, -2, 4, -2, 3, 3, -3, 1, -1, 2, -1, 4, 4, -2, -2, 4, -2, 0,
951 3, -3, -5, -1, -2, 4, -4, 2, -4, 0, -2, 3, -3, 2, 2, -2, -5, -1, 4,
952 3, -2, -1, 3, 3, -1, 3, 0, -3, 0, 4, 2, 0, -1, 4, 1, 1, 2, 1,
953 3, 1, 1, 1, -3, -5, -4, 4, -4, 2, 0, 0, -4, 1, 4, -5, 4, 4, 0,
954 1, 0, -2, -4, -4, -3, 0, 1, -5, 4, 0, -3, -2, -4, 2, 4, 1, -5, 1,
955 -4, 1, 0, -3, -3, 0, 2, -5, 4, 3, -2, -5, 3, 1, -1, 0, 3, -2, -2,
956 3, -2, -5, 4, 1, -2, 2, -1, 0, 4, 0, -5, 3, -2, 1, 2, 1, -5, -3,
957 -2, -5, 4, -4, 0, 3, 2, -1, -4, -1, 2, 1, -2, 3, -1, -4, 2, 0, -3,
958 1, -1, 2, -5, -4, -1, -5, 1, 4, 3, 4, 2, -3, 1, -5, -1, 3, 0, -1,
959 -4, 3, 4, -5, 4, 4, -3, 2, -3, -1, -3, -5, -3, 2, -3, -2, 1, 1, 0,
960 -5, 3, 2, 1, -5, 1, 1, 1, 3, 4, -4, -1, -2, 0, -5, -3, -5, -2, -4,
961 3, 3, 3, 4, 0, -4, -1, -5, 0, -3, 1, 4, 4, -4, 4, -5, -5, -1, -2,
962 -5, 3, -4, 4, 3, 0, -3, 2, -2, 0, 0, 4, 4, 0, -2, 1, -1, -3, 2,
963 -1, 1, -3, -5,
966 float bias_11x10[] = {
970 float expected_11x10[] = {
971 36, -84, 95, 45, 18, 46, 77, -54, -99, -149, 66, 49, 161, 11,
972 39, 61, -66, 61, 4, -3, 34, -44, -23, 31, 64, 29, 47, 72,
973 -27, -27, 121, -3, 100, 1, 30, -78, -12, -89, -59, 8, -16, 112,
974 91, -102, -26, -4, 30, 54, 4, -84, -24, -58, 27, -53, -33, 5,
975 53, -26, 63, 50, -103, -130, -23, 6, -104, -207, 73, 23, 77, 132,
976 38, 32, -130, -44, -60, 7, 27, 176, 45, -32, -2, 99, -97, 63,
977 69, 126, 47, 63, 136, -57, 5, 16, -40, -157, 8, 38, -44, -10,
978 91, 7, 122, 140, 30, -105, 4, -1, 113, 64, 180, 141,
// Second case: reuse the same config object, swapping in the second
// weight/bias set and overriding the filter size and skip values of layer 0.
981 cnn_config.layer_config[0].weights = weights_11x10;
982 cnn_config.layer_config[0].bias = bias_11x10;
983 cnn_config.layer_config[0].filter_width = 20;
984 cnn_config.layer_config[0].filter_height = 23;
985 cnn_config.layer_config[0].skip_width = 1;
986 cnn_config.layer_config[0].skip_height = 1;
987 image_height = 11;
988 image_width = 10;
990 RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
991 &cnn_config, image_width, &thread_data, MSE_INT_TOL);
// Runs one SOFTSIGN layer (filter_width = 4, filter_height = 5) over an 8x8
// input three times, switching layer 0's padding between PADDING_SAME_ZERO,
// PADDING_SAME_REPLICATE and PADDING_VALID. Each run is compared with its own
// expected array using MSE_FLOAT_TOL.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
994 TEST_F(CNNTest, TestSoftsignSingleLayer) {
995 int image_width = 8;
996 int image_height = 8;
997 int filter_height = 5;
998 int filter_width = 4;
999 float input[] = {
1000 -0.5220f, 0.8410f, -0.8990f, -0.0090f, 0.6710f, -0.9470f, -0.8240f,
1001 -0.0870f, 0.5380f, 0.4750f, 0.570f, -0.3760f, -0.6960f, -0.5940f,
1002 -0.3830f, 0.080f, -0.0980f, -0.4940f, -0.4030f, 0.9460f, -0.6020f,
1003 0.4220f, 0.6190f, 0.6640f, -0.9210f, -0.1470f, -0.2480f, -0.1120f,
1004 -0.580f, -0.0650f, 0.3330f, 0.9860f, -0.7430f, 0.7610f, 0.4840f,
1005 0.1030f, 0.9570f, 0.6120f, -0.5240f, -0.1220f, -0.5850f, -0.270f,
1006 0.7840f, -0.9790f, 0.7290f, -0.30f, -0.6460f, 0.0780f, 0.4750f,
1007 -0.0510f, 0.4550f, 0.3850f, -0.7230f, 0.4460f, -0.6260f, -0.810f,
1008 0.8720f, -0.2120f, -0.580f, -0.9510f, -0.8430f, -0.1340f, -0.0850f,
1009 0.9190f,
// Expected output for PADDING_SAME_ZERO.
1011 float expected_same[] = {
1012 0.430f, 0.660f, 0.5510f, -0.610f, 0.450f, -0.1610f, 0.0520f, 0.3240f,
1013 0.6820f, 0.3820f, 0.6360f, 0.7480f, 0.3080f, 0.090f, 0.3910f, 0.1730f,
1014 0.340f, 0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.4670f, 0.6150f,
1015 -0.3880f, 0.7590f, 0.4190f, 0.7350f, 0.5310f, -0.5160f, -0.1760f, 0.6790f,
1016 -0.6780f, 0.5470f, 0.5750f, -0.6420f, 0.7210f, -0.4620f, 0.5430f, 0.770f,
1017 -0.1990f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.6430f, 0.4510f,
1018 -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f, 0.5870f, 0.4720f,
1019 0.4040f, 0.3630f, 0.670f, 0.2360f, 0.410f, 0.6980f, -0.5350f, 0.3940f,
// Expected output for PADDING_SAME_REPLICATE.
1021 float expected_replicate[] = {
1022 0.540f, 0.7230f, -0.3530f, -0.2130f, 0.7440f, -0.4470f, -0.6260f,
1023 -0.2050f, 0.7230f, 0.4630f, 0.5920f, 0.7440f, 0.6080f, 0.3130f,
1024 -0.5670f, -0.4720f, 0.5480f, 0.6660f, -0.4990f, 0.4280f, 0.1540f,
1025 0.120f, 0.3390f, 0.6090f, 0.4160f, 0.7590f, 0.4190f, 0.7350f,
1026 0.5310f, -0.5160f, -0.490f, 0.4450f, -0.610f, 0.5470f, 0.5750f,
1027 -0.6420f, 0.7210f, -0.4620f, 0.3150f, 0.7370f, -0.5820f, 0.3950f,
1028 0.7860f, -0.4380f, 0.7540f, 0.2640f, -0.7430f, -0.5340f, -0.6270f,
1029 0.4430f, 0.4730f, 0.4570f, 0.7450f, 0.630f, 0.2620f, 0.3140f,
1030 -0.1840f, 0.1810f, 0.7210f, 0.2760f, 0.6430f, 0.6720f, -0.4390f,
1031 0.2040f,
// Expected output for PADDING_VALID.
1033 float expected_valid[] = {
1034 0.6660f, -0.4990f, 0.4280f, 0.1540f, 0.120f, 0.7590f, 0.4190f,
1035 0.7350f, 0.5310f, -0.5160f, 0.5470f, 0.5750f, -0.6420f, 0.7210f,
1036 -0.4620f, 0.3950f, 0.7860f, -0.4380f, 0.7540f, 0.2640f,
1038 float weights[] = {
1039 0.6210f, 0.3710f, -0.2770f, -0.7230f, -0.2450f, 0.6770f, 0.3080f,
1040 -0.9880f, -0.080f, 0.7190f, -0.6760f, -0.0170f, -0.8970f, 0.8260f,
1041 0.7390f, -0.4550f, -0.4260f, -0.6330f, 0.0880f, -0.9390f,
1043 float bias[] = {
1044 0.750f,
1047 CNN_CONFIG cnn_config = { 1,
1054 filter_width,
1055 filter_height,
1060 weights,
1061 bias,
1062 PADDING_SAME_ZERO,
1063 SOFTSIGN,
1066 BRANCH_NO_COPY,
1067 BRANCH_NOC,
1071 } } };
1073 CNN_THREAD_DATA thread_data = { 1, nullptr };
1075 RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
1076 image_width, &thread_data, MSE_FLOAT_TOL);
// Re-run the same layer with the other two padding modes.
1078 cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
1080 RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
1081 image_width, &thread_data, MSE_FLOAT_TOL);
1083 cnn_config.layer_config[0].pad = PADDING_VALID;
1085 RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
1086 image_width, &thread_data, MSE_FLOAT_TOL);
// Builds a CNN_CONFIG whose first field is 6 and in which one layer entry uses
// BRANCH_INPUT and a later one uses BRANCH_ADD, then runs it on a 4x4 input
// (filter_width = 2, filter_height = 3, channels = 2) and compares the result
// with `expected` using MSE_INT_TOL.
// NOTE(review): the hex literals that follow the branch enums look like branch
// bitmasks; confirm against the CNN_LAYER_CONFIG definition.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
1089 TEST_F(CNNTest, TestBranchTensorAdd) {
1090 int filter_width = 2;
1091 int filter_height = 3;
1093 int image_width = 4;
1094 int image_height = 4;
1096 float input[] = {
1097 -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1100 float weights[] = {
1101 -3, -1, 4, -1, -3, 3, 3, 0, 2, 0, 3, 2, 4, 4, 4, -5, 1, -4,
1102 2, -4, 1, -3, 0, 4, -5, 4, 0, -4, -3, -1, 0, 0, -2, 0, 0, 2,
1103 -5, -1, 1, -3, 3, 4, 3, 0, 1, -1, 1, 1, 2, 4, -2, -5, 2, -2,
1104 3, -2, 4, -1, 0, 2, 3, 2, -2, -1, -3, 1, 3, 4, -1, -3, 0, -4,
1105 4, 2, -3, -3, -1, 0, 1, 0, 3, 3, -3, 0, 3, 2, -5, -3, 4, -5,
1106 3, -1, -1, -3, 0, 1, -1, -4, 2, 4, -1, 4, -1, 1, 3, 4, 4, 4,
1107 0, -1, -3, -3, -3, -3, 2, -3, -2, 2, 3, -3,
1110 float bias[] = {
1111 3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
1114 float expected[] = {
1115 -11502, -4101, -3424, 668, -17950, -5470, -5504, 626,
1116 4835, 446, 1779, -3483, 3679, -4214, 4578, -105,
1119 int channels = 2;
1121 CNN_CONFIG cnn_config = { 6,
1128 filter_width,
1129 filter_height,
1130 channels,
1134 weights,
1135 bias,
1136 PADDING_SAME_ZERO,
1137 NONE,
1140 BRANCH_NO_COPY,
1141 BRANCH_NOC,
1147 channels,
1148 filter_width,
1149 filter_height,
1150 channels,
1154 nullptr,
1155 nullptr,
1156 PADDING_SAME_ZERO,
1157 NONE,
1160 BRANCH_INPUT,
1161 BRANCH_NOC,
1163 0x02,
1165 0x00,
1171 channels,
1172 filter_width,
1173 filter_height,
1174 channels,
1178 nullptr,
1179 nullptr,
1180 PADDING_SAME_ZERO,
1181 NONE,
1184 BRANCH_NO_COPY,
1185 BRANCH_NOC,
1191 channels,
1192 filter_width,
1193 filter_height,
1194 channels,
1198 nullptr,
1199 nullptr,
1200 PADDING_SAME_ZERO,
1201 NONE,
1204 BRANCH_NO_COPY,
1205 BRANCH_NOC,
1211 channels,
1212 filter_width,
1213 filter_height,
1214 channels,
1218 nullptr,
1219 nullptr,
1220 PADDING_SAME_ZERO,
1221 NONE,
1224 BRANCH_NO_COPY,
1225 BRANCH_ADD,
1227 0x00,
1229 0x02,
1235 channels,
1236 filter_width,
1237 filter_height,
1242 nullptr,
1243 nullptr,
1244 PADDING_SAME_ZERO,
1245 NONE,
1248 BRANCH_NO_COPY,
1249 BRANCH_NOC,
1253 } } };
1255 // Weights and biases need to be specified separately because
1256 // of the offset.
1257 AssignLayerWeightsBiases(&cnn_config, weights, bias);
1259 CNN_THREAD_DATA thread_data = { 1, nullptr };
1261 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1262 image_width, &thread_data, MSE_INT_TOL);
// Same layout as the tensor-add test, but the combining layer entry uses
// BRANCH_CAT and the final layer entry is given `channels + channels` where
// the other entries pass `channels`. Runs on a 4x4 input (filter_width = 2,
// filter_height = 3, channels = 2) and compares with `expected` using
// MSE_INT_TOL.
// NOTE(review): the hex literals that follow the branch enums look like branch
// bitmasks; confirm against the CNN_LAYER_CONFIG definition.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
1265 TEST_F(CNNTest, TestBranchTensorConcatenation) {
1266 int filter_width = 2;
1267 int filter_height = 3;
1269 int image_width = 4;
1270 int image_height = 4;
1272 float input[] = {
1273 -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
1276 float weights[] = {
1277 3, 0, 2, 0, 2, 3, 1, -3, 1, -5, -3, 0, -4, 4, 0, -5, 0, -5, -1,
1278 -2, -5, 0, -3, 2, -4, 2, 0, 2, -1, 0, -4, 3, 0, 0, -1, -5, 2, -1,
1279 4, -4, -2, -3, -3, 3, 4, -2, -1, -4, -1, 4, 4, -1, 4, 3, -4, 2, -2,
1280 -4, -3, -2, 3, -3, -5, -1, 3, -2, 4, 1, -4, -3, -5, -5, -3, 4, -2, -2,
1281 -1, -5, -5, 0, -1, -2, -3, 3, -4, -5, 2, -3, 1, 0, -5, 2, 2, -2, 0,
1282 2, 2, -2, 4, 2, 2, 0, 1, -5, -3, 0, 2, -2, 1, 2, -5, 2, 3, 3,
1283 -1, 3, 0, -3, 3, -4, -4, 3, 3, -4, -2, 2, -2, 2, -2, -1, 3, 0,
1286 float bias[] = {
1287 -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
1290 float expected[] = {
1291 -33533, -32087, -6741, -2124, 39979, 41453, 14034, 689,
1292 -22611, -42203, -14882, -239, 15781, 15963, 9524, 837,
1295 int channels = 2;
1297 CNN_CONFIG cnn_config = { 6,
1304 filter_width,
1305 filter_height,
1306 channels,
1310 weights,
1311 bias,
1312 PADDING_SAME_ZERO,
1313 NONE,
1316 BRANCH_NO_COPY,
1317 BRANCH_NOC,
1323 channels,
1324 filter_width,
1325 filter_height,
1326 channels,
1330 nullptr,
1331 nullptr,
1332 PADDING_SAME_ZERO,
1333 NONE,
1336 BRANCH_INPUT,
1337 BRANCH_NOC,
1339 0x02,
1341 0x00,
1347 channels,
1348 filter_width,
1349 filter_height,
1350 channels,
1354 nullptr,
1355 nullptr,
1356 PADDING_SAME_ZERO,
1357 NONE,
1360 BRANCH_NO_COPY,
1361 BRANCH_NOC,
1367 channels,
1368 filter_width,
1369 filter_height,
1370 channels,
1374 nullptr,
1375 nullptr,
1376 PADDING_SAME_ZERO,
1377 NONE,
1380 BRANCH_NO_COPY,
1381 BRANCH_NOC,
1387 channels,
1388 filter_width,
1389 filter_height,
1390 channels,
1394 nullptr,
1395 nullptr,
1396 PADDING_SAME_ZERO,
1397 NONE,
1400 BRANCH_NO_COPY,
1401 BRANCH_CAT,
1403 0x00,
1405 0x02,
1411 channels + channels,
1412 filter_width,
1413 filter_height,
1418 nullptr,
1419 nullptr,
1420 PADDING_SAME_ZERO,
1421 NONE,
1424 BRANCH_NO_COPY,
1425 BRANCH_NOC,
1429 } } };
1431 // Weights and biases need to be specified separately because
1432 // of the offset.
1433 AssignLayerWeightsBiases(&cnn_config, weights, bias);
1435 CNN_THREAD_DATA thread_data = { 1, nullptr };
1437 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1438 image_width, &thread_data, MSE_INT_TOL);
1441 // TODO(logangw): Add a test that covers all combinations of branch_copy_type.
// Builds a CNN_CONFIG whose first field is 10 and whose layer entries mix
// BRANCH_INPUT, BRANCH_OUTPUT and three BRANCH_ADD combines, then runs it on a
// 4x4 input (filter_width = 2, filter_height = 3, channels = 2) and compares
// the result with `expected` using MSE_INT_TOL.
// NOTE(review): the hex literals that follow the branch enums look like branch
// bitmasks; confirm against the CNN_LAYER_CONFIG definition.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
1443 TEST_F(CNNTest, TestBranchCombinations) {
1444 int filter_width = 2;
1445 int filter_height = 3;
1447 int image_width = 4;
1448 int image_height = 4;
1450 float input[] = {
1451 3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
1454 float weights[] = {
1455 2, 3, 0, 4, 4, 3, 1, 0, 1, -5, 4, -3, 3, 0, 4, -1, -1, -5,
1456 2, 1, -3, -5, 3, -1, -3, -2, 0, -2, 3, 0, -2, -4, -2, -2, 2, -5,
1457 4, -5, 0, 1, -5, -4, -3, -4, 2, -2, 1, 0, 3, -2, -4, 3, 4, -4,
1458 -1, -1, -3, -2, -2, -1, 2, 0, 2, -1, 2, -4, -4, -1, 2, 0, 3, -2,
1459 -2, 3, -3, 4, -2, 4, 3, 4, 1, 0, -2, -3, -5, 1, -3, 2, 0, -2,
1460 -2, -1, -1, -5, -2, -3, -1, 3, 3, 4, 4, 0, 2, 1, 3, -3, 2, -5,
1461 -5, 1, -5, -1, 3, 3, 2, -4, -1, 3, -4, -2, -5, -2, 1, 3, 2, 2,
1462 -5, -2, -3, -1, -2, -4, -1, -2, 2, 1, -4, -4, 2, 0, 2, 0, 2, -3,
1463 -2, -4, 4, 0, 1, -3, -5, 4, -1, 2, 3, -5, -1, 0, 4, -1, -1, 3,
1464 -1, -3, 3, 1, 4, 3, 4, 3, -4, -5, -1, 3, 3, -4, 3, 1, 3, -5,
1465 3, 4, -5, 4, 2, -1, -5, 2, 1, 0, 4, 0, -3, 2, 0, 2, -2, 1,
1466 -1, -2, -1, -5, 4, 3, 3, -2, 2, 4, -5, -5, -3, -2, 4, 0, -4, 1,
1469 float bias[] = {
1470 -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
1473 float expected[] = {
1474 149496, 15553, -24193, -20956, 134094, 86432, -68283, -6366,
1475 -53031, 133739, 67407, -13539, -53205, -58635, -20033, 1979,
1478 int channels = 2;
1480 CNN_CONFIG cnn_config = { 10,
1488 filter_width,
1489 filter_height,
1490 channels,
1494 weights,
1495 bias,
1496 PADDING_SAME_ZERO,
1497 NONE,
1500 BRANCH_NO_COPY,
1501 BRANCH_NOC,
1507 channels,
1508 filter_width,
1509 filter_height,
1510 channels,
1514 nullptr,
1515 nullptr,
1516 PADDING_SAME_ZERO,
1517 NONE,
1520 BRANCH_INPUT,
1521 BRANCH_NOC,
1523 0x06,
1525 0x00,
1531 channels,
1532 filter_width,
1533 filter_height,
1534 channels,
1538 nullptr,
1539 nullptr,
1540 PADDING_SAME_ZERO,
1541 NONE,
1544 BRANCH_OUTPUT,
1545 BRANCH_NOC,
1547 0x08,
1549 0x00,
1555 channels,
1556 filter_width,
1557 filter_height,
1558 channels,
1562 nullptr,
1563 nullptr,
1564 PADDING_SAME_ZERO,
1565 NONE,
1568 BRANCH_NO_COPY,
1569 BRANCH_NOC,
1575 channels,
1576 filter_width,
1577 filter_height,
1578 channels,
1582 nullptr,
1583 nullptr,
1584 PADDING_SAME_ZERO,
1585 NONE,
1588 BRANCH_NO_COPY,
1589 BRANCH_ADD,
1591 0x00,
1593 0x08,
1599 channels,
1600 filter_width,
1601 filter_height,
1602 channels,
1606 nullptr,
1607 nullptr,
1608 PADDING_SAME_ZERO,
1609 NONE,
1612 BRANCH_NO_COPY,
1613 BRANCH_NOC,
1619 channels,
1620 filter_width,
1621 filter_height,
1622 channels,
1626 nullptr,
1627 nullptr,
1628 PADDING_SAME_ZERO,
1629 NONE,
1632 BRANCH_NO_COPY,
1633 BRANCH_NOC,
1639 channels,
1640 filter_width,
1641 filter_height,
1642 channels,
1646 nullptr,
1647 nullptr,
1648 PADDING_SAME_ZERO,
1649 NONE,
1652 BRANCH_NO_COPY,
1653 BRANCH_ADD,
1655 0x00,
1657 0x0C,
1663 channels,
1664 filter_width,
1665 filter_height,
1666 channels,
1670 nullptr,
1671 nullptr,
1672 PADDING_SAME_ZERO,
1673 NONE,
1676 BRANCH_NO_COPY,
1677 BRANCH_ADD,
1679 0x00,
1681 0x02,
1687 channels,
1688 filter_width,
1689 filter_height,
1694 nullptr,
1695 nullptr,
1696 PADDING_SAME_ZERO,
1697 NONE,
1700 BRANCH_NO_COPY,
1701 BRANCH_NOC,
1706 } };
1708 // Weights and biases need to be specified separately because
1709 // of the offset.
1710 AssignLayerWeightsBiases(&cnn_config, weights, bias);
1712 CNN_THREAD_DATA thread_data = { 1, nullptr };
1714 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1715 image_width, &thread_data, MSE_INT_TOL);
// Builds a CNN_CONFIG whose first field is 3: the first visible layer entry
// uses BRANCH_OUTPUT, the second combines with BRANCH_CAT, and the third has
// no branch copy or combine. Runs on a 4x4 input (filter_width = 2,
// filter_height = 3) and compares with `expected` using MSE_INT_TOL.
// NOTE(review): the hex literals that follow the branch enums look like branch
// bitmasks; confirm against the CNN_LAYER_CONFIG definition.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
1718 TEST_F(CNNTest, TestSplittingTensors) {
1719 int filter_width = 2;
1720 int filter_height = 3;
1722 int image_width = 4;
1723 int image_height = 4;
1725 float input[] = {
1726 -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
1729 float weights[] = {
1730 -4, 1, 0, 2, 3, 4, 4, -4, -5, -3, 2, 2, -4, -3, 3, 2,
1731 4, -4, -3, -4, -4, 1, -3, -5, -3, 4, 2, -2, 2, -1, -4, -1,
1732 -2, -3, 1, 1, 0, -5, -1, 3, 3, -5, -3, 0, -3, 1, -3, -1,
1733 1, -3, -2, -2, 4, -2, 0, 1, 2, 2, -4, 2, 4, 0, -5, -2,
1734 4, 4, -5, 1, 0, 2, -2, -5, -5, -3, -5, -5, 4, -3, 0, 0,
1735 -4, -4, 0, -5, -4, 0, 0, -3, -5, -3, -1, 2, -1, 4, -1, 2,
1738 float bias[] = {
1739 -4, -2, -3, -3, 3, 1, -2,
1742 float expected[] = {
1743 530, -762, 1469, 777, 849, -771, -1698, 600,
1744 -658, -1821, 98, -668, -1798, 30, 887, -971,
1747 CNN_CONFIG cnn_config = { 3,
1755 filter_width,
1756 filter_height,
1761 nullptr,
1762 nullptr,
1763 PADDING_SAME_ZERO,
1764 NONE,
1767 BRANCH_OUTPUT,
1768 BRANCH_NOC,
1770 0x02,
1772 0x00,
1779 filter_width,
1780 filter_height,
1785 nullptr,
1786 nullptr,
1787 PADDING_SAME_ZERO,
1788 NONE,
1791 BRANCH_NO_COPY,
1792 BRANCH_CAT,
1794 0x00,
1796 0x02,
1803 filter_width,
1804 filter_height,
1809 nullptr,
1810 nullptr,
1811 PADDING_SAME_ZERO,
1812 NONE,
1815 BRANCH_NO_COPY,
1816 BRANCH_NOC,
1821 } };
1823 // Weights and biases need to be specified separately because
1824 // of the offset.
1825 AssignLayerWeightsBiases(&cnn_config, weights, bias);
1827 CNN_THREAD_DATA thread_data = { 1, nullptr };
1829 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1830 image_width, &thread_data, MSE_INT_TOL);
// Runs an all-zero 2x2 input through a CNN_CONFIG whose first field is 3,
// using 1x1 filters with all-zero weights and biases. The first visible layer
// entry uses BRANCH_INPUT and the next two combine with BRANCH_CAT. The result
// is compared with an all-zero `expected` array of 28 values using
// MSE_FLOAT_TOL.
// NOTE(review): with every value zero, this presumably checks only the number
// of output channels (as the test name suggests), not arithmetic; confirm.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
1833 TEST_F(CNNTest, TestOutputChannelsCount) {
1834 int filter_width = 1;
1835 int filter_height = 1;
1837 int image_width = 2;
1838 int image_height = 2;
1840 float input[] = { 0, 0, 0, 0 };
1842 float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
1844 float bias[] = { 0, 0, 0, 0, 0, 0 };
1846 float expected[] = {
1847 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1848 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1851 CNN_CONFIG cnn_config = { 3,
1859 filter_width,
1860 filter_height,
1865 weights,
1866 bias,
1867 PADDING_SAME_ZERO,
1868 NONE,
1871 BRANCH_INPUT,
1872 BRANCH_NOC,
1874 0x06,
1876 0x00,
1883 filter_width,
1884 filter_height,
1889 weights,
1890 bias,
1891 PADDING_SAME_ZERO,
1892 NONE,
1895 BRANCH_NO_COPY,
1896 BRANCH_CAT,
1898 0x00,
1900 0x03,
1907 filter_width,
1908 filter_height,
1913 weights,
1914 bias,
1915 PADDING_SAME_ZERO,
1916 NONE,
1919 BRANCH_NO_COPY,
1920 BRANCH_CAT,
1922 0x00,
1924 0x04,
1929 } };
1931 // Weights and biases need to be specified separately because
1932 // of the offset.
1933 AssignLayerWeightsBiases(&cnn_config, weights, bias);
1935 CNN_THREAD_DATA thread_data = { 1, nullptr };
1937 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
1938 image_width, &thread_data, MSE_FLOAT_TOL);
// Runs a 28x28 input through a layer configured with a 7x7 filter,
// PADDING_VALID, RELU and a CNN_BATCHNORM_PARAMS built from the bn_gamma,
// bn_beta, bn_mean and bn_std arrays (three values each). The result is
// compared with `expected` using MSE_FLOAT_TOL.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing; verify the
// initializers against the upstream file.
1941 TEST_F(CNNTest, TestBatchNorm) {
1942 int image_width = 28;
1943 int image_height = 28;
1944 int filter_height = 7;
1945 int filter_width = 7;
1946 float input[] = {
1947 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1948 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1949 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1950 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1951 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1952 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1953 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1954 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1955 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1956 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1957 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1958 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1959 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1960 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1961 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1962 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1963 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1964 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1965 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1966 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1967 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1968 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1969 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1970 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1971 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1972 0.0f, 0.0f, 0.0117647f, 0.0705882f, 0.0705882f, 0.0705882f,
1973 0.494118f, 0.533333f, 0.686275f, 0.101961f, 0.65098f, 1.0f,
1974 0.968627f, 0.498039f, 0.0f, 0.0f, 0.0f, 0.0f,
1975 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1976 0.0f, 0.0f, 0.117647f, 0.141176f, 0.368627f, 0.603922f,
1977 0.666667f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.992157f,
1978 0.882353f, 0.67451f, 0.992157f, 0.94902f, 0.764706f, 0.25098f,
1979 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1980 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.192157f,
1981 0.933333f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.992157f,
1982 0.992157f, 0.992157f, 0.992157f, 0.984314f, 0.364706f, 0.321569f,
1983 0.321569f, 0.219608f, 0.152941f, 0.0f, 0.0f, 0.0f,
1984 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1985 0.0f, 0.0f, 0.0f, 0.0705882f, 0.858824f, 0.992157f,
1986 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.776471f, 0.713725f,
1987 0.968627f, 0.945098f, 0.0f, 0.0f, 0.0f, 0.0f,
1988 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1989 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1990 0.0f, 0.0f, 0.313725f, 0.611765f, 0.419608f, 0.992157f,
1991 0.992157f, 0.803922f, 0.0431373f, 0.0f, 0.168627f, 0.603922f,
1992 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1993 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1994 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1995 0.0f, 0.054902f, 0.00392157f, 0.603922f, 0.992157f, 0.352941f,
1996 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1997 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1998 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
1999 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2000 0.0f, 0.545098f, 0.992157f, 0.745098f, 0.00784314f, 0.0f,
2001 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2002 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2003 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2004 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0431373f,
2005 0.745098f, 0.992157f, 0.27451f, 0.0f, 0.0f, 0.0f,
2006 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2007 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2008 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2009 0.0f, 0.0f, 0.0f, 0.0f, 0.137255f, 0.945098f,
2010 0.882353f, 0.627451f, 0.423529f, 0.00392157f, 0.0f, 0.0f,
2011 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2012 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2013 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2014 0.0f, 0.0f, 0.0f, 0.317647f, 0.941176f, 0.992157f,
2015 0.992157f, 0.466667f, 0.0980392f, 0.0f, 0.0f, 0.0f,
2016 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2017 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2018 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2019 0.0f, 0.0f, 0.176471f, 0.729412f, 0.992157f, 0.992157f,
2020 0.588235f, 0.105882f, 0.0f, 0.0f, 0.0f, 0.0f,
2021 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2022 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2023 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2024 0.0f, 0.0627451f, 0.364706f, 0.988235f, 0.992157f, 0.733333f,
2025 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2026 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2027 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2028 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2029 0.0f, 0.976471f, 0.992157f, 0.976471f, 0.25098f, 0.0f,
2030 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2031 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2032 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2033 0.0f, 0.0f, 0.180392f, 0.509804f, 0.717647f, 0.992157f,
2034 0.992157f, 0.811765f, 0.00784314f, 0.0f, 0.0f, 0.0f,
2035 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2036 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2037 0.0f, 0.0f, 0.0f, 0.0f, 0.152941f, 0.580392f,
2038 0.898039f, 0.992157f, 0.992157f, 0.992157f, 0.980392f, 0.713725f,
2039 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2040 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2041 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2042 0.0941176f, 0.447059f, 0.866667f, 0.992157f, 0.992157f, 0.992157f,
2043 0.992157f, 0.788235f, 0.305882f, 0.0f, 0.0f, 0.0f,
2044 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2045 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2046 0.0f, 0.0f, 0.0901961f, 0.258824f, 0.835294f, 0.992157f,
2047 0.992157f, 0.992157f, 0.992157f, 0.776471f, 0.317647f, 0.00784314f,
2048 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2049 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2050 0.0f, 0.0f, 0.0f, 0.0f, 0.0705882f, 0.670588f,
2051 0.858824f, 0.992157f, 0.992157f, 0.992157f, 0.992157f, 0.764706f,
2052 0.313725f, 0.0352941f, 0.0f, 0.0f, 0.0f, 0.0f,
2053 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2054 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2055 0.215686f, 0.67451f, 0.886275f, 0.992157f, 0.992157f, 0.992157f,
2056 0.992157f, 0.956863f, 0.521569f, 0.0431373f, 0.0f, 0.0f,
2057 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2058 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2059 0.0f, 0.0f, 0.0f, 0.0f, 0.533333f, 0.992157f,
2060 0.992157f, 0.992157f, 0.831373f, 0.529412f, 0.517647f, 0.0627451f,
2061 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2062 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2063 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2064 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2065 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2066 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2067 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2068 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2069 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2070 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2071 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2072 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2073 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2074 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2075 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2076 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
2077 0.0f, 0.0f, 0.0f, 0.0f
2079 float expected[] = {
2080 -0.836424f, -0.857365f, -1.62739f, -1.62739f, -0.836424f, 5.40742f,
2081 0.920853f, -0.692567f, -0.836424f, -0.534405f, -1.62739f, -0.836424f,
2082 1.32602f, 1.36312f, 0.112766f, -0.836424f, -0.192962f, 1.56975f,
2083 2.45777f, 0.944414f, -0.192962f, -1.5519f, -1.5519f, -0.554006f,
2084 -0.192962f, 1.4231f, -1.5519f, -0.192962f, 1.3661f, -1.5519f,
2085 -1.5519f, -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
2086 -0.843708f, 4.53065f, 0.0429584f, -0.796804f, -0.843708f, 0.3473f,
2087 -0.843708f, -0.843708f, -0.114439f, 3.14817f, 0.0811934f, -0.843708f
2089 float kernel[] = {
2090 0.119643f, -0.237864f, 0.0462892f, 0.0502297f, -0.0134528f,
2091 0.146347f, 0.153133f, 0.0513307f, 0.0752369f, 0.0135557f,
2092 -0.111434f, 0.0941854f, 0.0788362f, 0.0299412f, 0.111762f,
2093 0.144066f, 0.00431504f, -0.0177954f, 0.0738092f, -0.0344215f,
2094 0.0832582f, 0.053989f, -0.112691f, 0.0962145f, 0.0186525f,
2095 -0.00660205f, -0.111962f, -0.126801f, -0.231625f, 0.17309f,
2096 0.0748875f, -0.179569f, -0.00513812f, -0.156579f, -0.147322f,
2097 0.184168f, 0.189308f, -0.200359f, -0.0156733f, 0.140649f,
2098 0.0858496f, -0.0263217f, -0.0740749f, -0.112563f, 0.107528f,
2099 0.0609729f, -0.221625f, 0.0769944f, -0.00900815f, -0.00136441f,
2100 -0.0236521f, -0.0418025f, -0.00286299f, 0.12241f, 0.0964093f,
2101 -0.0150897f, 0.0532171f, 0.0625916f, 0.116939f, 0.118024f,
2102 0.161918f, -0.00909767f, 0.100897f, -0.054563f, -0.175179f,
2103 -0.0687892f, 0.00734235f, 0.109833f, -0.113776f, 0.0595405f,
2104 -0.170255f, 0.0124815f, -0.0363301f, -0.0127038f, 0.0445554f,
2105 -0.0729894f, 0.107428f, -0.0341417f, 0.132619f, 0.00984557f,
2106 -0.00443654f, 0.202929f, 0.0945134f, 0.0148725f, 0.00998574f,
2107 -0.0226449f, 0.0478197f, -0.0793442f, 0.0707599f, -0.084225f,
2108 0.0865795f, 0.071104f, -0.047894f, 0.0838322f, 0.0635493f,
2109 -0.00370265f, -0.157247f, -0.0289622f, -0.0590963f, 0.13207f,
2110 0.00468011f, -0.0345372f, 0.217939f, 0.18861f, -0.0290393f,
2111 -0.0440664f, 0.0126197f, -0.129132f, -0.124943f, 0.0968156f,
2112 -0.0853643f, -0.182305f, 0.00461618f, -0.147095f, -0.230282f,
2113 0.00856019f, 0.0278893f, -0.0300229f, 0.0417871f, 0.0804717f,
2114 -0.0768571f, -0.0397085f, -0.0601096f, 0.100901f, -0.0184926f,
2115 0.0350673f, 0.0971094f, -0.0171837f, -0.289644f, -0.0899041f,
2116 0.08998f, -0.160319f, -0.0195103f, 0.0392167f, -0.137864f,
2117 -0.0136294f, 0.0330886f, -0.0409244f, -0.092533f, -0.0427934f,
2118 -0.191144f, -0.0969461f, 0.112035f, 0.138611f, 0.128717f,
2119 0.191184f, 0.197462f
2121 float bias[] = { 0.186703f, 0.204358f, -0.0230452f };
// Batch-norm parameter arrays, three values each, bundled into bn_params.
2123 float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
2124 float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
2125 float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
2126 float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
2128 CNN_BATCHNORM_PARAMS bn_params = {
2129 bn_gamma,
2130 bn_beta,
2131 bn_mean,
2132 bn_std,
2135 CNN_CONFIG cnn_config = {
2144 filter_width,
2145 filter_height,
2150 kernel,
2151 bias,
2152 PADDING_VALID,
2153 RELU,
2156 BRANCH_NO_COPY,
2157 BRANCH_NOC,
2159 bn_params,
2165 CNN_THREAD_DATA thread_data = { 1, nullptr };
2167 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2168 image_width, &thread_data, MSE_FLOAT_TOL);
// Runs the same configuration (2x2 image, 3x3 filter, PADDING_SAME_ZERO) twice
// against the same `expected` array: first with thread_data = { 1, nullptr },
// then with thread_data = { 4, workers } using four AVxWorker objects. Both
// runs use MSE_FLOAT_TOL.
// NOTE(review): gaps in the embedded line numbers suggest that lines holding
// only literals or closing braces were lost from this listing (the input and
// bias values are not visible here); verify against the upstream file.
2171 TEST_F(CNNTest, TestMultithreading) {
2172 int image_height = 2;
2173 int image_width = 2;
2174 int filter_height = 3;
2175 int filter_width = 3;
2177 float input[] = {
2184 float weights[] = {
2185 -4, 2, -2, 0, -4, 4, -3, -3, -3, -1, 1, 0, -5, -3, 0, -5, 0, 0,
2186 -1, 0, 2, -5, 0, 1, 4, 2, 1, 0, -2, -1, -5, -3, 2, -2, 1, -5,
2189 float bias[] = {
2196 float expected[] = {
2197 2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
2200 CNN_CONFIG cnn_config = {
2209 filter_width,
2210 filter_height,
2215 weights,
2216 bias,
2217 PADDING_SAME_ZERO,
2218 NONE,
2221 BRANCH_NO_COPY,
2222 BRANCH_NOC,
// Baseline run with a thread count of 1 and no worker array.
2230 CNN_THREAD_DATA thread_data = { 1, nullptr };
2232 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2233 image_width, &thread_data, MSE_FLOAT_TOL);
2235 const AVxWorkerInterface *const winterface = aom_get_worker_interface();
2236 AVxWorker workers[4];
// Initialise each worker through the worker interface before it is used.
2238 for (int i = 0; i < 4; ++i) {
2239 winterface->init(&workers[i]);
// Second run: same inputs and expected output, now with four workers.
2242 thread_data = { 4, workers };
2244 RunCNNTest(image_width, image_height, input, expected, &cnn_config,
2245 image_width, &thread_data, MSE_FLOAT_TOL);
// Release every worker once the multi-worker run has finished.
2247 for (int i = 0; i < 4; ++i) {
2248 winterface->end(&workers[i]);
2252 TEST_F(CNNTest, TestMultiOutput) {
2253 const int image_dim = 8;
2254 const int image_ch = 3;
2255 const int filter_dim = 2;
2256 const int stride = 2;
2257 const int num_filters = 2;
2259 const float input_[] = {
2260 1.7537929121f, 0.134331551012f, 0.123580039877f, 0.957731845246f,
2261 0.391006834217f, 1.00699352042f, -0.778177955829f, -0.814166433059f,
2262 -0.656374394915f, 0.321967305228f, -2.19455719176f, 0.708035038966f,
2263 0.409148822266f, -0.318254408902f, 0.152450211189f, -0.250210793369f,
2264 0.826811563186f, 1.6804156584f, 0.273626975978f, 0.437936241887f,
2265 -0.329935520167f, -0.288761611645f, 0.156937008304f, 0.271054157295f,
2266 -0.0224828854332f, 1.70110336895f, -0.989066699309f, 1.30863131729f,
2267 -0.165813705702f, 0.00380178619265f, -0.0837342367587f, 0.760954783156f,
2268 -0.413610373524f, 1.17968204175f, 0.720295719536f, 0.308718974472f,
2269 -1.10091337671f, 0.693160033687f, -0.0202862320697f, 1.0221927503f,
2270 -1.24521801881f, -0.478501952308f, -1.71648619442f, -0.182571723636f,
2271 0.339292649504f, 2.0806519131f, 0.967974033444f, 0.175248672328f,
2272 0.0658124561472f, 0.795504169496f, 0.750592557361f, -1.46631013249f,
2273 -1.79052846838f, -1.03672179515f, -0.841985521653f, 1.20995011489f,
2274 0.140859718215f, -0.651552622661f, 0.451065110806f, 1.1189443693f,
2275 0.100213260593f, -0.834076868118f, -1.28734321611f, 1.22064420095f,
2276 -0.364143084361f, 0.750961509335f, -0.888689074553f, -0.8253547106f,
2277 -1.21800999027f, -0.966670603566f, 1.37384014741f, 0.47281264834f,
2278 -0.420416235531f, 0.520163906493f, 0.501296589423f, 1.53418976951f,
2279 0.715234751485f, 0.644551588907f, 0.0763504863375f, -0.0018541943723f,
2280 0.322853189656f, -0.795099723224f, -0.125177096675f, 1.4476577471f,
2281 -0.585888410088f, -1.44391754955f, -0.610543221933f, -0.221859179799f,
2282 0.252060200774f, -0.86287169623f, -0.0350246229157f, 1.0932311997f,
2283 0.899464648842f, -0.468806951704f, -0.300861137168f, 1.15776414206f,
2284 1.03268544738f, -0.171579585622f, -0.179136557119f, -0.354091003368f,
2285 -0.612298249394f, -1.20237379258f, 1.54604109659f, 0.130664370287f,
2286 0.885225111868f, 1.0362799581f, 0.980561720868f, -0.619379186999f,
2287 -1.33818929924f, -0.237233737961f, -1.89335425073f, 0.567821011321f,
2288 0.862420368465f, -1.37380916821f, 0.352190056666f, 0.611261516274f,
2289 0.393237747152f, 0.894686247967f, 0.190405182149f, 0.264872662911f,
2290 -0.0657009133797f, 0.0580512653493f, -0.401825294366f, 0.4106081318f,
2291 0.49484512188f, -0.0751103149442f, -1.43243736382f, 1.79855656009f,
2292 -1.1075351975f, 0.000354882733011f, -0.950716438608f, 1.27129831688f,
2293 1.00495189838f, 0.110358656713f, 1.08315032822f, -0.972676676218f,
2294 -0.0757668962831f, 1.88932045165f, -0.0672638136275f, 0.425913010161f,
2295 -0.781540372017f, 0.976000248609f, 0.687218504122f, 1.31374513445f,
2296 -0.932658930672f, -1.25339468479f, 0.422071294078f, -0.24189927912f,
2297 0.216906604642f, -1.88720997548f, 1.99252872889f, 0.353943735777f,
2298 0.737434784132f, -1.17848645017f, 1.70424254896f, 0.775297112968f,
2299 -0.516392797501f, 0.398130609129f, 0.737248101457f, 0.166282500886f,
2300 1.24699015468f, 0.47116183125f, 1.19091180182f, -0.372695424578f,
2301 0.219773209389f, -0.829467838962f, -0.52533122724f, 1.98707754595f,
2302 0.553692606972f, -0.933228902369f, 1.55427751643f, -1.08813399144f,
2303 -0.325686682094f, 0.205091443796f, -1.70381666435f, 0.466465327942f,
2304 1.73126863447f, -0.939133672634f, 1.48318077459f, -0.599414038168f,
2305 -1.1583078687f, 0.518116190201f, 0.133571482458f, 0.84958342672f,
2306 1.02205000597f, -0.0772082009087f, -1.69567503859f, 1.4697939436f,
2307 1.67813743122f, -0.627911582938f, 0.131380509137f, -1.35717850726f,
2309 const float *input[3] = { input_, &input_[image_dim * image_dim],
2310 &input_[2 * image_dim * image_dim] };
2312 const float bias[] = { 0.0f, 0.0f };
2314 const float weights_1[] = {
2315 -0.489547413618f, 0.141916424749f, -0.279286485585f, -0.115322211094f,
2316 0.299572786936f, 0.205289980785f, -0.536254480088f, -0.253626313744f,
2317 -0.422883815849f, -0.169702966298f, -0.540104704793f, 0.495319646763f,
2318 0.298799079422f, -0.10054550901f, -0.306085047056f, 0.171061886165f,
2319 -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
2320 -0.157203423678f, -0.362138920529f, -0.216206085209f, 0.147502517971f,
2323 const float weights_2[] = {
2324 0.207580604357f, 0.480821146263f, -0.29111909562f, 0.47422567493f,
2325 0.206892553253f, -0.235067084092f, 0.354516800602f, -0.212399370252f,
2326 -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
2327 0.567044864811f, -0.060341127522f, 0.0501464839637f, -0.437785677916f,
2330 const float weights_3[] = {
2331 -0.0690452401448f, -0.356657338763f, -0.219464031809f, 0.551288365843f,
2332 0.181372090853f, -0.00245268542109f, 0.409000696276f, -0.593209108763f,
2333 0.587352566749f, -0.243720660227f, 0.266232713887f, -0.00439285245097f,
2334 0.252883228305f, 0.152646192631f, 0.0918944932026f, 0.398853715057f,
2337 const float weights_4[] = {
2338 0.207560791573f, 0.194201350401f, 0.227802322443f, 0.206533663345f,
2339 0.0557331066805f, 0.0224159800424f, -0.143939197467f, -0.27703361602f,
2340 0.130643888389f, -0.269456557461f, 0.186242862864f, -0.162879944774f,
2341 -0.145503996718f, -0.0768822987581f, -0.203127976359f, -0.238119922873f,
2342 -0.258806479994f, 0.0357957680385f, -0.1027606976f, -0.287920082345f,
2343 0.189047820993f, 0.250711538481f, -0.272815714175f, -0.0431449742024f,
2344 0.207261230996f, -0.0396472677451f, 0.131236557412f, 0.174291832499f,
2345 -0.251515885765f, -0.107164007499f, 0.185824534748f, -0.00561585838161f,
2346 0.273393799578f, -0.139563699075f, -0.263922456031f, -0.118859844081f,
2347 0.109230982597f, -0.170170294794f, 0.0123025648515f, -0.0839368964355f,
2348 -0.0774058234297f, 0.255847138286f, -0.208430879637f, 0.279170114319f,
2349 -0.272890330712f, -0.217725903006f, -0.295923275459f, -0.17008723953f,
2350 -0.284281803405f, 0.281406323629f, 0.266910044663f, -0.209963914338f,
2351 0.271980962964f, 0.142013581699f, -0.143896509026f, -0.290509242975f,
2352 -0.305768180935f, 0.196902832117f, -0.090424189662f, -0.147460802346f,
2353 0.217722016651f, 0.12353848977f, -0.169177363577f, -0.0454230918512f,
2356 const float expected_0[] = {
2357 -2.04858441055f, -2.12883075791f, -0.045177363807f, 0.763949675768f,
2358 -0.544361512821f, -1.58123168032f, 1.89319847039f, 0.16859080901f,
2359 -1.16023321135f, -0.396988107751f, 1.76637090744f, -1.40434786514f,
2360 0.908227575669f, 0.817064817605f, 0.215631134908f, -0.848605613428f,
2361 -0.106756747018f, 0.0193027166685f, 0.801345615113f, -0.395407237598f,
2362 -1.79983795658f, -1.73054496242f, 0.0584392594454f, -0.388786095569f,
2363 -0.237269619354f, 0.000843578271263f, -1.24043512104f, 0.487839445893f,
2364 -0.394259726605f, 0.559632843424f, -0.527224052291f, -1.53792340282f,
2367 const float expected_1[] = {
2368 0.0f, 0.0f, 0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
2369 0.0f, 1.22013465602f,
2372 const float expected_2[] = {
2373 0.156119444687f,
2374 0.517385299817f,
2377 const float expected_3[] = {
2378 0.224177852984f,
2379 0.503384419034f,
2380 0.156119444687f,
2381 0.517385299817f,
2384 const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };
2386 CNN_CONFIG cnn_config = {
2387 4, // num_layers
2388 0, // is_residue
2389 0, // ext_width
2390 0, // ext_height
2391 0, // strict_bounds
2393 // layer_config
2395 image_ch, // in_channels
2396 filter_dim, // filter_width
2397 filter_dim, // filter_height
2398 num_filters, // out_channels
2399 stride, // skip_width
2400 stride, // skip_height
2401 0, // max_pool
2402 weights_1, // weights
2403 bias, // bias
2404 PADDING_SAME_ZERO, // pad
2405 NONE, // activation
2406 0, // deconvolve
2407 0, // branch
2408 BRANCH_OUTPUT, // branch_copy_type
2409 BRANCH_NOC, // branch_combine_type
2410 { 2, 0, 0 }, // branch_config
2411 {}, // bn_params
2412 0, // output_num
2415 num_filters, // in_channels
2416 filter_dim, // filter_width
2417 filter_dim, // filter_height
2418 num_filters, // out_channels
2419 stride, // skip_width
2420 stride, // skip_height
2421 0, // max_pool
2422 weights_2, // weights
2423 bias, // bias
2424 PADDING_SAME_ZERO, // pad
2425 RELU, // activation
2426 0, // deconvolve
2427 0, // branch
2428 BRANCH_NO_COPY, // branch_copy_type
2429 BRANCH_NOC, // branch_combine_type
2430 {}, // branch_config
2431 {}, // bn_params
2432 1, // output_num
2435 num_filters, // in_channels
2436 filter_dim, // filter_width
2437 filter_dim, // filter_height
2438 num_filters, // out_channels
2439 stride, // skip_width
2440 stride, // skip_height
2441 0, // max_pool
2442 weights_3, // weights
2443 bias, // bias
2444 PADDING_SAME_ZERO, // pad
2445 RELU, // activation
2446 0, // deconvolve
2447 0, // branch
2448 BRANCH_NO_COPY, // branch_copy_type
2449 BRANCH_NOC, // branch_combine_type
2450 {}, // branch_config
2451 {}, // bn_params
2452 2, // output_num
2455 num_filters, // in_channels
2456 2 * filter_dim, // filter_width
2457 2 * filter_dim, // filter_height
2458 num_filters, // out_channels
2459 2 * stride, // skip_width
2460 2 * stride, // skip_height
2461 0, // max_pool
2462 weights_4, // weights
2463 bias, // bias
2464 PADDING_VALID, // pad
2465 RELU, // activation
2466 0, // deconvolve
2467 1, // branch
2468 BRANCH_NO_COPY, // branch_copy_type
2469 BRANCH_CAT, // branch_combine_type
2470 { 0, 0, 1 }, // branch_config
2471 {}, // bn_params
2472 3, // output_num
2477 CNN_THREAD_DATA thread_data = { 1, nullptr };
2479 const int num_outputs = 4;
2480 const int output_chs[4] = { filter_dim, filter_dim, filter_dim,
2481 2 * filter_dim };
2482 const int output_dims[4] = { 4, 2, 1, 1 };
2483 const int output_sizes[4] = {
2484 output_chs[0] * output_dims[0] * output_dims[0],
2485 output_chs[1] * output_dims[1] * output_dims[1],
2486 output_chs[2] * output_dims[2] * output_dims[2],
2487 output_chs[3] * output_dims[3] * output_dims[3],
2489 float *const output_ = (float *)aom_malloc(
2490 sizeof(*output_) *
2491 (output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3]));
2492 ASSERT_NE(output_, nullptr);
2493 float *output[CNN_MAX_CHANNELS] = { nullptr };
2494 int ch_ite = 0;
2495 float *output_ite = output_;
2496 for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
2497 for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
2498 output[ch_ite++] = output_ite;
2499 output_ite += output_dims[output_idx] * output_dims[output_idx];
2502 CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
2503 output };
2505 RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
2506 &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
2508 aom_free(output_);
2511 namespace {
2513 typedef void (*CNNConvolveNoMaxpoolPaddingValidFunc)(
2514 const float **input, int in_width, int in_height, int in_stride,
2515 const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
2516 int start_idx, int cstep, int channel_step);
2518 typedef libaom_test::FuncParam<CNNConvolveNoMaxpoolPaddingValidFunc>
2519 CNNConvolveTestFuncs;
2521 class CNNConvolveTest : public ::testing::TestWithParam<CNNConvolveTestFuncs> {
2522 protected:
2523 void SetUp() override { params_ = GetParam(); }
2525 void RunCNNConvolveSetup(int run_times) {
2526 int in_width = 65;
2527 int in_height = 65;
2529 const CNN_CONFIG *cnn_config = &av1_intra_mode_cnn_partition_cnn_config;
2531 for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
2532 int out_width = 0, out_height = 0;
2533 int in_size = in_width * in_height;
2534 // Get current layer output width and height.
2535 av1_find_cnn_layer_output_size(in_height, in_width,
2536 &cnn_config->layer_config[layer],
2537 &out_width, &out_height);
2539 int out_size = out_width * out_height;
2540 float *input[20], *output_ref[20], *output_mod[20];
2542 float *input_data =
2543 (float *)aom_malloc(sizeof(*input_data) * in_size *
2544 cnn_config->layer_config[layer].in_channels);
2545 float *temp_ptr = input_data;
2546 ASSERT_NE(temp_ptr, nullptr);
2547 for (int i = 0; i < cnn_config->layer_config[layer].in_channels; ++i) {
2548 input[i] = temp_ptr;
2549 for (int j = 0; j < in_size; j++) {
2550 *(temp_ptr++) = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
2554 float *out_data_ref = (float *)aom_calloc(
2555 sizeof(*out_data_ref),
2556 out_size * cnn_config->layer_config[layer].out_channels);
2557 ASSERT_NE(out_data_ref, nullptr);
2558 float *out_data_mod = (float *)aom_calloc(
2559 sizeof(*out_data_mod),
2560 out_size * cnn_config->layer_config[layer].out_channels);
2561 ASSERT_NE(out_data_mod, nullptr);
2562 float *temp_ptr1 = out_data_ref;
2563 float *temp_ptr2 = out_data_mod;
2564 for (int i = 0; i < cnn_config->layer_config[layer].out_channels; ++i) {
2565 output_ref[i] = temp_ptr1;
2566 output_mod[i] = temp_ptr2;
2567 temp_ptr1 += out_size;
2568 temp_ptr2 += out_size;
2571 RunCNNConvolveTest(input, in_width, in_height, out_size,
2572 &cnn_config->layer_config[layer], 0, 1, run_times,
2573 layer, output_ref, output_mod, out_width);
2575 // Set current layer output width and height as next layer input width and
2576 // height.
2577 in_width = out_width;
2578 in_height = out_height;
2580 aom_free(input_data);
2581 aom_free(out_data_ref);
2582 aom_free(out_data_mod);
2586 void RunCNNConvolveTest(float **input, int in_width, int in_height,
2587 int out_size, const CNN_LAYER_CONFIG *layer_config,
2588 int start_idx, int step, int run_times, int layer,
2589 float **output_ref, float **output_mod,
2590 int out_stride) {
2591 const int cstep = layer_config->in_channels * layer_config->out_channels;
2592 const int channel_step = AOMMAX(step, 1);
2593 aom_usec_timer timer;
2594 aom_usec_timer_start(&timer);
2595 for (int i = 0; i < run_times; ++i) {
2596 params_.ref_func((const float **)input, in_width, in_height, in_width,
2597 layer_config, output_ref, out_stride, start_idx, cstep,
2598 channel_step);
2600 aom_usec_timer_mark(&timer);
2601 const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2603 aom_usec_timer_start(&timer);
2604 for (int i = 0; i < run_times; ++i) {
2605 params_.tst_func((const float **)input, in_width, in_height, in_width,
2606 layer_config, output_mod, out_stride, start_idx, cstep,
2607 channel_step);
2609 aom_usec_timer_mark(&timer);
2610 const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
2612 if (run_times > 1) {
2613 printf("layer : %d \n", layer);
2614 printf("%7.2f/%7.2fns (%3.2f)\n", time1, time2, time1 / time2);
2615 } else {
2616 for (int channel = 0; channel < layer_config->out_channels; ++channel) {
2617 const float *buf_ref = output_ref[channel];
2618 const float *buf_mod = output_mod[channel];
2620 for (int i = 0; i < out_size; ++i) {
2621 if (buf_ref[i] < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL) {
2622 ASSERT_LE(buf_ref[i], CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2623 << "Reference output was near-zero, test output was not ("
2624 << buf_mod[i] << ")";
2625 } else {
2626 const float error = buf_ref[i] - buf_mod[i];
2627 const float relative_error = fabsf(error / buf_ref[i]);
2628 ASSERT_LE(relative_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
2629 << " channel " << channel << " pixel " << i << ": "
2630 << buf_ref[i] << "/" << buf_mod[i] << std::endl;
2637 private:
2638 CNNConvolveTestFuncs params_;
2639 libaom_test::ACMRandom rng_;
2641 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest);
2643 TEST_P(CNNConvolveTest, CheckOutput) { RunCNNConvolveSetup(1); }
2645 TEST_P(CNNConvolveTest, DISABLED_Speed) { RunCNNConvolveSetup(100000); }
#if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
// C reference paired with the AVX2 implementation.
const CNNConvolveTestFuncs kCNNConvolveTestFuncsAVX2[] = {
  CNNConvolveTestFuncs(&av1_cnn_convolve_no_maxpool_padding_valid_c,
                       &av1_cnn_convolve_no_maxpool_padding_valid_avx2),
};

INSTANTIATE_TEST_SUITE_P(AVX2, CNNConvolveTest,
                         ::testing::ValuesIn(kCNNConvolveTestFuncsAVX2));
#endif
#if HAVE_NEON
// C reference paired with the NEON implementation.
const CNNConvolveTestFuncs kCNNConvolveTestFuncsNEON[] = {
  CNNConvolveTestFuncs(&av1_cnn_convolve_no_maxpool_padding_valid_c,
                       &av1_cnn_convolve_no_maxpool_padding_valid_neon),
};

INSTANTIATE_TEST_SUITE_P(NEON, CNNConvolveTest,
                         ::testing::ValuesIn(kCNNConvolveTestFuncsNEON));
#endif
2661 } // namespace