Fixed wrong prediction for incomplete ranges.
[fic.git] / modules / saupePredictor.cpp
blob 2b39cb72e00ceda9f058bf805ba9329d1e309b87
#include "saupePredictor.h"
#include "stdDomains.h" // because of HalfShrinker
using namespace std;
IStdEncPredictor::IOneRangePredictor* MSaupePredictor
::newPredictor(const NewPredictorData &data) {
	// ensure the levelTrees vector is long enough
	int level= data.rangeBlock->level;
	if ( level >= (int)levelTrees.size() )
		levelTrees.resize( level+1, (Tree*)0 );
	// ensure the tree is built for the level
	Tree *tree= levelTrees[level];
	if (!tree)
		tree= levelTrees[level]= createTree(data);
	ASSERT(tree);
	// get the max. number of domains to predict and create the predictor
	int maxPredicts= (int)ceil(maxPredCoeff()*tree->count);
	if (maxPredicts<=0)
		maxPredicts= 1;
	OneRangePredictor *result=
		new OneRangePredictor( data, settingsInt(ChunkSize), *tree, maxPredicts );

	#ifndef NDEBUG // collecting debugging stats
	maxpred+= tree->count*(data.allowRotations?8:1)*(data.allowInversion?2:1);
	result->predicted= &predicted;
	#endif

	return result;
}
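/* Usage sketch (illustrative only - the names `predictor`, `data` and `maxSE` are
 * hypothetical, and we assume the IOneRangePredictor interface exposes getChunk
 * as defined at the bottom of this file):
 * \code
 *	IStdEncPredictor::IOneRangePredictor *pred= predictor.newPredictor(data);
 *	MSaupePredictor::Predictions preds;
 *	pred->getChunk( maxSE, preds ); // fills `preds` with candidate domain IDs
 * \endcode */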
namespace MatrixWalkers {
	/** Creates a Checked structure to walk a linear array as a matrix */
	template<class T,class I> Checked<T,I> makeLinearizer(T* start,I colSkip,I width,I height) {
		return Checked<T,I>
			( MatrixSlice<T,I>::makeRaw(start,colSkip), Block(0,0,width,height) );
	}
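	/* Sketch of the intended use (types assumed from the calls elsewhere in this file):
	 * \code
	 *	SReal pixels[4*4];
	 *	// walk a 4x4 block stored row-by-row with a row stride of 4 values
	 *	Checked<SReal> walker= makeLinearizer( pixels, 4, 4, 4 );
	 * \endcode */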
	/** Walker that iterates over rectangles in a matrix and returns their sums,
	 *	to be constructed by ::makeSumWalker */
	template < class SumT, class PixT, class I >
	struct SumWalker {
		SummedMatrix<SumT,PixT,I> matrix;
		I width, height, y0; //, xEnd, yEnd;
		I x, y;

		//bool outerCond() { return x!=xEnd; }
		void outerStep() { x+= width; }

		void innerInit() { y= y0; }
		//bool innerCond() { return y!=yEnd; }
		void innerStep() { y+= height; }

		SumT get() { return matrix.getValueSum(x,y,x+width,y+height); }
	};
	/** Constructs a square SumWalker on \p matrix starting at [\p x0,\p y0], iterating
	 *	over squares of level (\p inLevel-\p outLevel) and covering a square of level
	 *	\p inLevel (so it yields a 2^outLevel x 2^outLevel grid of sums) */
	template < class SumT, class PixT, class I >
	SumWalker<SumT,PixT,I> makeSumWalker
	( const SummedMatrix<SumT,PixT,I> &matrix, I x0, I y0, I inLevel, I outLevel ) {
		ASSERT( x0>=0 && y0>=0 && inLevel>=0 && outLevel>=0 && inLevel>=outLevel );
		SumWalker<SumT,PixT,I> result;
		result.matrix= matrix;
		result.width= result.height= powers[inLevel-outLevel];
		result.y0= y0;
		//result.xEnd= x0+powers[inLevel];
		//result.yEnd= y0+powers[inLevel];
		result.x= x0;
		DEBUG_ONLY( result.y= numeric_limits<I>::max(); )
		return result;
	}
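	/* Worked example: for inLevel=4 and outLevel=2 the walker steps by
	 * width=height=powers[4-2]=4 pixels, covering the 16x16 square at [x0,y0]
	 * and yielding a 4x4 grid of block sums - one per get() call. */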
	/** Similar to HalfShrinker, but instead of multiplying the sums by 0.25
	 *	it uses an arbitrary number */
	template<class T,class I=PtrInt,class R=Real>
	struct HalfShrinkerMul: public HalfShrinker<T,I> { ROTBASE_INHERIT
		R toMul;

		HalfShrinkerMul( TMatrix matrix, R toMul_, I x0=0, I y0=0 )
		: HalfShrinker<T,I>(matrix,x0,y0), toMul(toMul_) {}

		T get() {
			TMatrix &c= current;
			T *cs= c.start;
			return toMul * ( cs[0] + cs[1] + cs[c.colSkip] + cs[c.colSkip+1] );
		}
	};
} // MatrixWalkers namespace
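/* Note: with toMul=0.25 HalfShrinkerMul reproduces plain HalfShrinker (the average
 * of each 2x2 pixel group); ::refineBlock below passes the normalizing coefficient
 * instead, presumably to fold part of the normalization into the shrinking pass. */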
namespace NOSPACE {
	typedef MSaupePredictor::KDReal KDReal;
	/** The common part of ::refineDomain and ::refineRange - it does the actual
	 *	shrinking (if needed). The \p multiply parameter is expected to normalize
	 *	pixels, not sums. */
	inline static void refineBlock( const SummedPixels &pixMatrix, int x0, int y0
	, int predWidth, int predHeight, Real multiply, Real avg
	, int realLevel, int predLevel, SReal *pixelResult ) {
		using namespace MatrixWalkers;
		// adjust the multiplication coefficient for normalizing from sums of values
		int levelDiff= realLevel-predLevel;
		if (levelDiff)
			multiply= ldexp(multiply,-2*levelDiff);
		ASSERT( finite(multiply) && finite(avg) );
		// construct the operator and the walker on the result
		AddMulCopy<Real> oper( -avg, multiply );
		Checked<KDReal> resWalker=
			makeLinearizer( pixelResult, powers[predLevel], predWidth, predHeight );
		// decide the shrinking method
		if (levelDiff==0) // no shrinking
			walkOperate( resWalker, Rotation_0<SReal>(pixMatrix.pixels,x0,y0), oper );
		else if (levelDiff==1) // shrinking by groups of four
			walkOperate( resWalker
				, HalfShrinkerMul<SReal>(pixMatrix.pixels,multiply,x0,y0), oper );
		else // levelDiff>=2 - shrinking by bigger groups, using the summer
			walkOperate( resWalker
				, makeSumWalker(pixMatrix,x0,y0,realLevel,predLevel), oper );
	}
}
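/* The ldexp call above divides `multiply` by 4^levelDiff; e.g. for levelDiff=1 it
 * becomes multiply/4, so the sum over a 2x2 pixel group times multiply/4 equals the
 * group's average times the original per-pixel coefficient. */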
void MSaupePredictor::refineDomain( const SummedPixels &pixMatrix, int x0, int y0
, bool allowInversion, int realLevel, int predLevel, SReal *pixelResult ) {
	// compute the average and the standard deviation (data needed for normalization)
	int realSide= powers[realLevel];
	Real sum, sum2;
	pixMatrix.getSums(x0,y0,x0+realSide,y0+realSide).unpack(sum,sum2);
	Real avg= ldexp(sum,-2*realLevel); // i.e. avg= sum / (2^realLevel)^2
	// the same as multiply= 1 / sqrt( sum2 - sqr(sum)/pixelCount )
	Real multiply= 1 / sqrt( sum2 - sqr(ldexp(sum,-realLevel)) );
	if ( !finite(multiply) )
		multiply= 1; // it would be better not to add such domains into the tree at all
	// if inversion is allowed and the first pixel is below average, invert the block
	if ( allowInversion && pixMatrix.pixels[x0][y0] < avg )
		multiply= -multiply;
	// do the actual work
	int predSide= powers[predLevel];
	refineBlock( pixMatrix, x0, y0, predSide, predSide, multiply, avg
		, realLevel, predLevel, pixelResult );
}
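/* The identity used above: sqr(ldexp(sum,-realLevel)) == sqr(sum)/4^realLevel
 * == sqr(sum)/pixelCount, since the block has 4^realLevel pixels. E.g. for
 * realLevel=2 (a 4x4 block, 16 pixels): ldexp(sum,-2) == sum/4, squared sum^2/16. */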
bool MSaupePredictor::refineRange
( const NewPredictorData &data, int predLevel, SReal *pixelResult ) {
	const ISquareRanges::RangeNode &rb= *data.rangeBlock;
	int levelDiff= rb.level-predLevel
		, shrWidth= rShift(rb.width(),levelDiff)
		, shrHeight= rShift(rb.height(),levelDiff);
	Real avg, multiply;
	if ( data.isRegular || predLevel==rb.level ) { // no cropping of the block is needed
		// compute the average and the standard deviation (data needed for normalization)
		avg= data.rSum/data.pixCount;
		multiply= ( data.isRegular ? powers[rb.level] : sqrt(data.pixCount) ) / data.rnDev;
	} else { // the block needs cropping
		if ( shrWidth*shrHeight <= 1 )
			return false;
		// crop the range block so it doesn't contain parts of shrunk pixels
		int mask= powers[rb.level-predLevel]-1;
		int adjWidth= rb.width() & ~mask;
		int adjHeight= rb.height() & ~mask;
		int pixCount= adjWidth*adjHeight;
		// compute the coefficients needed for normalization (assertion tests are in ::refineBlock)
		Real sum, sum2;
		data.rangePixels->getSums( rb.x0, rb.y0, rb.x0+adjWidth, rb.y0+adjHeight )
			.unpack(sum,sum2);
		avg= sum/pixCount;
		multiply= 1 / sqrt( sum2 - sqr(sum)/pixCount );
	}
	// do the actual work
	refineBlock( *data.rangePixels, rb.x0, rb.y0, shrWidth, shrHeight
		, multiply, avg, rb.level, predLevel, pixelResult );
	return true;
}
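/* Worked example of the cropping: for rb.level-predLevel==2 the mask is
 * powers[2]-1 == 3, so a 13x10 range block is cropped to adjWidth= 13&~3 == 12
 * and adjHeight= 10&~3 == 8 - the largest multiples of the 4-pixel shrinking step. */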
MSaupePredictor::Tree* MSaupePredictor::createTree(const NewPredictorData &data) {
	// compute some accelerators
	const ISquareEncoder::LevelPoolInfos::value_type &poolInfos= *data.poolInfos;
	const int domainCount= poolInfos.back().indexBegin
		, realLevel= data.rangeBlock->level
		, predLevel= getPredLevel(realLevel)
		, realSide= powers[realLevel]
		, predPixCount= powers[2*predLevel];
	ASSERT(realLevel>=predLevel);
	// create space for temporary domain pixels - it can be too big to fit on the stack
	KDReal *domPix= new KDReal[ domainCount * predPixCount ];
	// init domain blocks from every pool
	KDReal *domPixNow= domPix;
	int poolCount= data.pools->size();
	for (int poolID=0; poolID<poolCount; ++poolID) {
		// check we are at the place we expect; get the current pool, density, etc.
		ASSERT( domPixNow-domPix == poolInfos[poolID].indexBegin*predPixCount );
		const ISquareDomains::Pool &pool= (*data.pools)[poolID];
		int density= poolInfos[poolID].density;
		if (!density) // no domains in this pool for this level
			continue;
		int poolXend= density*getCountForDensity( pool.width, density, realSide );
		int poolYend= density*getCountForDensity( pool.height, density, realSide );
		// handle the domain block at [x0,y0] (for each block in the pool)
		for (int x0=0; x0<poolXend; x0+=density)
			for (int y0=0; y0<poolYend; y0+=density) {
				refineDomain( pool, x0, y0, data.allowInversion, realLevel, predLevel, domPixNow );
				domPixNow+= predPixCount;
			}
	}
	ASSERT( domPixNow-domPix == domainCount*predPixCount ); // check we are just at the end
	// create the tree from the obtained data
	Tree *result= Tree::Builder
		::makeTree( domPix, predPixCount, domainCount, &Tree::Builder::chooseApprox );
	// clean up temporaries, return the tree
	delete[] domPix;
	return result;
}
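/* Layout note: domPix is filled as a domainCount x predPixCount row-major matrix,
 * one row per shrunk-and-normalized domain block; Tree::Builder::makeTree is assumed
 * to treat it the same way - domainCount points of dimension predPixCount. */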
namespace MatrixWalkers {
	/** Transformer performing an affine function (res= (f+toAdd)*toMul) */
	template<class T> struct AddMulCopyTo2nd {
		T toAdd, toMul;

		AddMulCopyTo2nd(T add,T mul)
		: toAdd(add), toMul(mul) {}

		template<class R1,class R2> void operator()(R1 f,R2 &res) const
			{ res= (f+toAdd)*toMul; }

		void innerEnd() const {}
	};

	/** A simple assigning operator - assigns its first argument into the second one */
	struct Assigner: public OperatorBase {
		template<class R1,class R2> void operator()(const R1 &src,R2 &dest) const
			{ dest= src; }
	};

	/** Transformer performing a sign change */
	struct SignChanger {
		template<class R1,class R2> void operator()(R1 src,R2 &dest) const
			{ dest= -src; }
	};
} // MatrixWalkers namespace
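/* These functors are meant for the walkOperate family of algorithms: the constructor
 * below uses Assigner with walkOperateCheckRotate to copy the refined block into its
 * rotated versions, and SignChanger with FieldMath::transform2 to negate them. */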
MSaupePredictor::OneRangePredictor::OneRangePredictor
( const NewPredictorData &data, int chunkSize_, const Tree &tree, int maxPredicts )
: chunkSize(chunkSize_), predsRemain(maxPredicts)
, firstChunk(true), allowRotations(data.allowRotations), isRegular(data.isRegular)
{
	// compute some accelerators and allocate space for the normalized range block
	// (plus its rotations and inversions)
	int rotationCount= allowRotations ? 8 : 1;
	heapCount= rotationCount * (data.allowInversion ? 2 : 1);
	points= new KDReal[tree.length*heapCount];

	// if the block isn't regular, fill the space with NaNs (they remain in unused places)
	if (!isRegular)
		fill( points, points+tree.length, numeric_limits<KDReal>::quiet_NaN() );

	// find out the prediction level (from the tree's point dimension) and the prediction side length
	int predLevel= log2ceil(tree.length)/2;
	int predSideLen= powers[predLevel];
	ASSERT( powers[2*predLevel] == tree.length );
	// compute the SE-normalizing accelerator
	errorNorm.initialize(data,predLevel);

	if (!refineRange( data, predLevel, points ))
		return;

	// if rotations are allowed, create the rotations of the refined block
	if (allowRotations) {
		using namespace MatrixWalkers;
		// create a walker for the refined (and not rotated) block (including possible NaNs)
		Checked<KDReal> refBlockWalker=
			makeLinearizer( points, predSideLen, predSideLen, predSideLen );

		MatrixSlice<KDReal> rotMatrix= MatrixSlice<KDReal>::makeRaw(points,predSideLen);
		Block shiftedBlock( 0, 0, predSideLen, predSideLen );

		for (int rot=1; rot<rotationCount; ++rot) {
			rotMatrix.start+= tree.length; // shift the matrix to the next rotation
			walkOperateCheckRotate
				( refBlockWalker, Assigner(), rotMatrix, shiftedBlock, rot );
		}
		ASSERT( rotMatrix.start == points+tree.length*(rotationCount-1) );
	}

	// create inversions of the rotations if needed
	if (data.allowInversion) {
		KDReal *pointsMiddle= points+tree.length*rotationCount;
		FieldMath::transform2
			( points, pointsMiddle, pointsMiddle, MatrixWalkers::SignChanger() );
	}

	// create all the heaps and initialize their infos
	heaps.reserve(heapCount);
	infoHeap.reserve(heapCount);
	for (int i=0; i<heapCount; ++i) {
		PointHeap *heap= new PointHeap( tree, points+i*tree.length, !data.isRegular );
		heaps.push_back(heap);
		infoHeap.push_back(HeapInfo( i, heap->getTopSE() ));
	}
	// build the heap from the heap infos
	make_heap( infoHeap.begin(), infoHeap.end() );
}
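/* Layout of the `points` buffer as built above: heapCount consecutive blocks of
 * tree.length values each - the 1 or 8 rotations first, then (when inversion is
 * allowed) their sign-inverted copies; this is why getChunk can recover the
 * rotation from a heap index as index%8. */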
MSaupePredictor::Predictions& MSaupePredictor::OneRangePredictor
::getChunk(float maxPredictedSE,Predictions &store) {
	if ( infoHeap.empty() || predsRemain<=0 ) {
		store.clear();
		return store;
	}
	ASSERT( PtrInt(heaps.size())==heapCount && PtrInt(infoHeap.size())<=heapCount );
	// get the number of predictions to make (it may be bigger for the first chunk)
	int predCount= chunkSize;
	if (firstChunk) {
		firstChunk= false;
		if (heapCount>predCount)
			predCount= heapCount;
	}
	// respect the limit on the remaining prediction count
	if (predCount>predsRemain)
		predCount= predsRemain;
	predsRemain-= predCount;
	// compute the max. normalized SE to predict
	float maxNormalizedSE= errorNorm.normSE(maxPredictedSE);
	// make a local working copy of the result (the prediction), adjust its size
	Predictions result;
	swap(result,store); // swapping is the quickest way
	result.resize(predCount);
	// generate the predictions
	for (Predictions::iterator it=result.begin(); it!=result.end(); ++it) {
		pop_heap( infoHeap.begin(), infoHeap.end() );
		HeapInfo &bestInfo= infoHeap.back();
		// if the error is too big, cut the vector and exit the loop
		if ( bestInfo.bestError > maxNormalizedSE ) {
			result.erase( it, result.end() );
			infoHeap.clear(); // to be able to exit more quickly on the next call
			break;
		}
		// fill the prediction and pop the heap
		ASSERT( 0<=bestInfo.index && bestInfo.index<heapCount );
		PointHeap &bestHeap= *heaps[bestInfo.index];
		it->domainID= isRegular
			? bestHeap.popLeaf<false>(maxNormalizedSE)
			: bestHeap.popLeaf<true>(maxNormalizedSE);
		it->rotation= allowRotations ? bestInfo.index%8 : 0; // modulo 8 because inverted blocks follow the rotations
		// check whether the heap has just been emptied
		if ( !bestHeap.isEmpty() ) {
			// update the info and reinsert it into infoHeap
			bestInfo.bestError= bestHeap.getTopSE();
			push_heap( infoHeap.begin(), infoHeap.end() );
		} else { // we've just emptied a heap
			infoHeap.pop_back();
			// stop if that was the last heap
			if ( infoHeap.empty() )
				break;
		}
	}
	// return the result
	swap(result,store);

	#ifndef NDEBUG
	*predicted+= store.size();
	#endif

	return store;
}
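/* Usage sketch (illustrative only, the names are hypothetical): callers are
 * expected to pull chunks until the predictor runs out of candidates:
 * \code
 *	MSaupePredictor::Predictions preds;
 *	while ( !predictor->getChunk(maxSE,preds).empty() )
 *		processCandidates(preds); // try the domains, possibly lowering maxSE
 * \endcode */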