// Extracted from: gecko.git / third_party / gemmology / gemmology_fwd.h
// (blob 83e3719b4fb39fc177780c2903c828ff156b45ae)
// Repository-viewer header at extraction time: "no bug - Bumping Firefox l10n changesets r=release a=l10n-bump DONTBUILD CLOSED TREE"
1 /***************************************************************
2 * _ *
3 * | | *
4 * __ _ ___ _ __ ___ _ __ ___ ___ | | ___ __ _ _ _ *
5 * / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _ \| |/ _ \ / _` | | | | *
6 * | (_| | __/ | | | | | | | | | | (_) | | (_) | (_| | |_| | *
7 * \__, |\___|_| |_| |_|_| |_| |_|\___/|_|\___/ \__, |\__, | *
8 * __/ | __/ | __/ | *
9 * |___/ |___/ |___/ *
10 * *
11 * version 0.1 *
12 ***************************************************************/
14 #ifndef GEMMOLOGY_FWD_H
15 #define GEMMOLOGY_FWD_H
17 #include <cstdint>
18 #include <cstring>
19 #include <tuple>
20 #include <xsimd/xsimd.hpp>
22 namespace gemmology {
24 namespace callbacks {
26 struct Unquantize {
27 float unquant_mult;
28 template <class Arch>
29 xsimd::batch<float, Arch> operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t, size_t);
30 template <class Arch>
31 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> operator()(
32 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
33 total,
34 size_t, size_t, size_t);
37 struct AddBias {
38 const float *bias_addr;
39 template <class Arch>
40 xsimd::batch<float, Arch> operator()(xsimd::batch<float, Arch> total, size_t, size_t col_idx,
41 size_t);
42 template <class Arch>
43 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
44 operator()(
45 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
46 size_t, size_t col_idx, size_t);
49 struct Write {
50 float *output_addr;
52 Write(float *o) : output_addr(o) {}
54 template <class Arch>
55 void operator()(xsimd::batch<float, Arch> result, size_t row_idx,
56 size_t col_idx, size_t col_size);
57 template <class Arch>
58 void operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
59 size_t col_idx, size_t col_size);
61 template <class Arch>
62 void operator()(
63 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
64 size_t row_idx, size_t col_idx, size_t col_size);
66 template <class Arch>
67 void operator()(
68 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
69 result,
70 size_t row_idx, size_t col_idx, size_t col_size);
73 struct UnquantizeAndWrite {
75 Unquantize unquantize;
76 Write write;
78 UnquantizeAndWrite(float factor, float *output)
79 : unquantize{factor}, write{output} {}
81 template <class T>
82 void operator()(T const &total, size_t row_idx, size_t col_idx,
83 size_t col_size);
86 struct UnquantizeAndAddBiasAndWrite {
88 Unquantize unquantize;
89 AddBias add_bias;
90 Write write;
92 UnquantizeAndAddBiasAndWrite(float factor, const float *bias, float *output)
93 : unquantize{factor}, add_bias{bias}, write{output} {}
95 template <class T>
96 void operator()(T const &total, size_t row_idx, size_t col_idx,
97 size_t col_size);
100 } // namespace callbacks
// Arch-specific implementation of each routine
/// Per-architecture set of static routines, selected by the `Arch` template
/// argument. Everything here is a declaration only; the definitions are not
/// part of this header.
template <class Arch> struct Engine {
  /// Float input to uint8 output, `size` elements, using `quant_mult`.
  static void QuantizeU(const float *input, uint8_t *output, float quant_mult,
                        size_t size);

  /// Float input to int8 output, `size` elements, using `quant_mult`.
  static void Quantize(const float *const input, int8_t *const output,
                       float quant_mult, size_t size);

  /// Takes an int8 input/output pair, a row count, and a half-open range of
  /// column indices `[cols_begin, cols_end)` of integer type `IntegerTy`.
  template <typename IntegerTy>
  static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows,
                             const IntegerTy *cols_begin,
                             const IntegerTy *cols_end);

  /// Note the argument order: `cols` comes before `rows` here, the opposite
  /// of `PrepareB` / `PrepareA`.
  static void PrepareBTransposed(const float *input, int8_t *output,
                                 float quant_mult, size_t cols, size_t rows);

  /// Same `cols`-then-`rows` order as `PrepareBTransposed`, but the input is
  /// already int8 and no multiplier is taken.
  static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output,
                                          size_t cols, size_t rows);

  /// Float input to int8 output; dimensions given as `rows`, then `cols`.
  static void PrepareB(const float *input, int8_t *output_shadow,
                       float quant_mult, size_t rows, size_t cols);

  /// Float input to int8 output; dimensions given as `rows`, then `cols`.
  static void PrepareA(const float *input, int8_t *output, float quant_mult,
                       size_t rows, size_t cols);

  /// Variants in which the A operand is uint8 rather than int8.
  struct Shift {
    /// Float input to uint8 output; dimensions given as `rows`, then `cols`.
    static void PrepareA(const float *input, uint8_t *output, float quant_mult,
                         size_t rows, size_t cols);

    /// Takes uint8 `A`, int8 `B`, their dimensions, and a callback object
    /// passed by value.
    template <class Callback>
    static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows,
                         size_t width, size_t B_cols, Callback callback);

    /// Takes int8 `B`, its dimensions, and a callback object passed by value.
    template <class Callback>
    static void PrepareBias(const int8_t *B, size_t width, size_t B_cols,
                            Callback C);
  };
};
// Top-level wrappers that mostly match intgemm API
149 template <class Arch = xsimd::default_arch>
150 inline void QuantizeU(const float *input, uint8_t *output, float quant_mult,
151 size_t size) {
152 return Engine<Arch>::QuantizeU(input, output, quant_mult, size);
155 template <class Arch = xsimd::default_arch>
156 inline void Quantize(const float *const input, int8_t *const output,
157 float quant_mult, size_t size) {
158 return Engine<Arch>::Quantize(input, output, quant_mult, size);
161 template <class Arch = xsimd::default_arch, typename IntegerTy>
162 inline void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows,
163 const IntegerTy *cols_begin,
164 const IntegerTy *cols_end) {
165 return Engine<Arch>::SelectColumnsB(input, output, rows, cols_begin,
166 cols_end);
169 template <class Arch = xsimd::default_arch>
170 inline void PrepareBTransposed(const float *input, int8_t *output,
171 float quant_mult, size_t cols, size_t rows) {
172 return Engine<Arch>::PrepareBTransposed(input, output, quant_mult, cols,
173 rows);
176 template <class Arch = xsimd::default_arch>
177 inline void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output,
178 size_t cols, size_t rows) {
179 return Engine<Arch>::PrepareBQuantizedTransposed(input, output, cols, rows);
182 template <class Arch = xsimd::default_arch>
183 inline void PrepareB(const float *input, int8_t *output_shadow,
184 float quant_mult, size_t rows, size_t cols) {
185 return Engine<Arch>::PrepareB(input, output_shadow, quant_mult, rows, cols);
188 template <class Arch = xsimd::default_arch>
189 inline void PrepareA(const float *input, int8_t *output, float quant_mult,
190 size_t rows, size_t cols) {
191 return Engine<Arch>::PrepareA(input, output, quant_mult, rows, cols);
194 namespace Shift {
196 template <class Arch = xsimd::default_arch>
197 inline void PrepareA(const float *input, uint8_t *output, float quant_mult,
198 size_t rows, size_t cols) {
199 return Engine<Arch>::Shift::PrepareA(input, output, quant_mult, rows, cols);
202 template <class Arch = xsimd::default_arch, class Callback>
203 inline void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows,
204 size_t width, size_t B_cols, Callback C) {
205 return Engine<Arch>::Shift::Multiply(A, B, A_rows, width, B_cols, C);
208 template <class Arch = xsimd::default_arch, class Callback>
209 inline void PrepareBias(const int8_t *B, size_t width, size_t B_cols,
210 Callback C) {
211 return Engine<Arch>::Shift::PrepareBias(B, width, B_cols, C);
214 } // namespace Shift
216 } // namespace gemmology
218 #endif