/***************************************************************
 *                        gemmology                            *
 *   forward declarations for SIMD integer GEMM routines       *
 ***************************************************************/
14 #ifndef GEMMOLOGY_FWD_H
15 #define GEMMOLOGY_FWD_H
20 #include <xsimd/xsimd.hpp>
29 xsimd::batch
<float, Arch
> operator()(xsimd::batch
<int32_t, Arch
> total
, size_t, size_t, size_t);
31 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>> operator()(
32 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
34 size_t, size_t, size_t);
38 const float *bias_addr
;
40 xsimd::batch
<float, Arch
> operator()(xsimd::batch
<float, Arch
> total
, size_t, size_t col_idx
,
43 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>>
45 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>> total
,
46 size_t, size_t col_idx
, size_t);
52 Write(float *o
) : output_addr(o
) {}
55 void operator()(xsimd::batch
<float, Arch
> result
, size_t row_idx
,
56 size_t col_idx
, size_t col_size
);
58 void operator()(xsimd::batch
<int32_t, Arch
> result
, size_t row_idx
,
59 size_t col_idx
, size_t col_size
);
63 std::tuple
<xsimd::batch
<float, Arch
>, xsimd::batch
<float, Arch
>> result
,
64 size_t row_idx
, size_t col_idx
, size_t col_size
);
68 std::tuple
<xsimd::batch
<int32_t, Arch
>, xsimd::batch
<int32_t, Arch
>>
70 size_t row_idx
, size_t col_idx
, size_t col_size
);
73 struct UnquantizeAndWrite
{
75 Unquantize unquantize
;
78 UnquantizeAndWrite(float factor
, float *output
)
79 : unquantize
{factor
}, write
{output
} {}
82 void operator()(T
const &total
, size_t row_idx
, size_t col_idx
,
86 struct UnquantizeAndAddBiasAndWrite
{
88 Unquantize unquantize
;
92 UnquantizeAndAddBiasAndWrite(float factor
, const float *bias
, float *output
)
93 : unquantize
{factor
}, add_bias
{bias
}, write
{output
} {}
96 void operator()(T
const &total
, size_t row_idx
, size_t col_idx
,
100 } // namespace callbacks
// Arch-specific implementation of each routine
// Per-architecture backend: every routine is specialized for `Arch`.
// The nested `Shift` group holds the unsigned-A ("shifted") variants.
// NOTE(review): `struct Shift {` and the closing braces were reconstructed —
// grounded by the wrappers below calling Engine<Arch>::Shift::PrepareA /
// Multiply / PrepareBias.
template <class Arch> struct Engine {

  // Quantize floats to uint8 with rounding; `size` elements.
  static void QuantizeU(const float *input, uint8_t *output, float quant_mult,
                        size_t size);

  // Quantize floats to int8; `size` elements.
  static void Quantize(const float *const input, int8_t *const output,
                       float quant_mult, size_t size);

  // Copy the columns [cols_begin, cols_end) of prepared B into `output`.
  template <typename IntegerTy>
  static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows,
                             const IntegerTy *cols_begin,
                             const IntegerTy *cols_end);

  // Quantize and reorder an already-transposed float B.
  static void PrepareBTransposed(const float *input, int8_t *output,
                                 float quant_mult, size_t cols, size_t rows);

  // Reorder an already-quantized, already-transposed int8 B.
  static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output,
                                          size_t cols, size_t rows);

  // Quantize and reorder B from row-major floats.
  static void PrepareB(const float *input, int8_t *output_shadow,
                       float quant_mult, size_t rows, size_t cols);

  // Quantize A to int8.
  static void PrepareA(const float *input, int8_t *output, float quant_mult,
                       size_t rows, size_t cols);

  struct Shift {

    // Quantize A to uint8 (shifted variant).
    static void PrepareA(const float *input, uint8_t *output, float quant_mult,
                         size_t rows, size_t cols);

    // C = A * B with unsigned A / signed B; `callback` handles each tile.
    template <class Callback>
    static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows,
                         size_t width, size_t B_cols, Callback callback);

    // Precompute the bias correction required by the shifted multiply.
    template <class Callback>
    static void PrepareBias(const int8_t *B, size_t width, size_t B_cols,
                            Callback callback);
  };
};
// Top-level wrappers that mostly match intgemm API
149 template <class Arch
= xsimd::default_arch
>
150 inline void QuantizeU(const float *input
, uint8_t *output
, float quant_mult
,
152 return Engine
<Arch
>::QuantizeU(input
, output
, quant_mult
, size
);
155 template <class Arch
= xsimd::default_arch
>
156 inline void Quantize(const float *const input
, int8_t *const output
,
157 float quant_mult
, size_t size
) {
158 return Engine
<Arch
>::Quantize(input
, output
, quant_mult
, size
);
161 template <class Arch
= xsimd::default_arch
, typename IntegerTy
>
162 inline void SelectColumnsB(const int8_t *input
, int8_t *output
, size_t rows
,
163 const IntegerTy
*cols_begin
,
164 const IntegerTy
*cols_end
) {
165 return Engine
<Arch
>::SelectColumnsB(input
, output
, rows
, cols_begin
,
169 template <class Arch
= xsimd::default_arch
>
170 inline void PrepareBTransposed(const float *input
, int8_t *output
,
171 float quant_mult
, size_t cols
, size_t rows
) {
172 return Engine
<Arch
>::PrepareBTransposed(input
, output
, quant_mult
, cols
,
176 template <class Arch
= xsimd::default_arch
>
177 inline void PrepareBQuantizedTransposed(const int8_t *input
, int8_t *output
,
178 size_t cols
, size_t rows
) {
179 return Engine
<Arch
>::PrepareBQuantizedTransposed(input
, output
, cols
, rows
);
182 template <class Arch
= xsimd::default_arch
>
183 inline void PrepareB(const float *input
, int8_t *output_shadow
,
184 float quant_mult
, size_t rows
, size_t cols
) {
185 return Engine
<Arch
>::PrepareB(input
, output_shadow
, quant_mult
, rows
, cols
);
188 template <class Arch
= xsimd::default_arch
>
189 inline void PrepareA(const float *input
, int8_t *output
, float quant_mult
,
190 size_t rows
, size_t cols
) {
191 return Engine
<Arch
>::PrepareA(input
, output
, quant_mult
, rows
, cols
);
196 template <class Arch
= xsimd::default_arch
>
197 inline void PrepareA(const float *input
, uint8_t *output
, float quant_mult
,
198 size_t rows
, size_t cols
) {
199 return Engine
<Arch
>::Shift::PrepareA(input
, output
, quant_mult
, rows
, cols
);
202 template <class Arch
= xsimd::default_arch
, class Callback
>
203 inline void Multiply(const uint8_t *A
, const int8_t *B
, size_t A_rows
,
204 size_t width
, size_t B_cols
, Callback C
) {
205 return Engine
<Arch
>::Shift::Multiply(A
, B
, A_rows
, width
, B_cols
, C
);
208 template <class Arch
= xsimd::default_arch
, class Callback
>
209 inline void PrepareBias(const int8_t *B
, size_t width
, size_t B_cols
,
211 return Engine
<Arch
>::Shift::PrepareBias(B
, width
, B_cols
, C
);
216 } // namespace gemmology