consistently implemented invertible coordinatetransforms and the
[trakem2.git] / lenscorrection / DistortionCorrectionTask.java
blob9a20c21402838806a32353bfbebee4b4a7b5dc3a
1 /**
2 *
3 */
4 package lenscorrection;
6 import java.awt.Rectangle;
7 import java.awt.geom.AffineTransform;
8 import java.util.ArrayList;
9 import java.util.Collection;
10 import java.util.HashMap;
11 import java.util.List;
12 import java.util.Set;
13 import java.util.concurrent.atomic.AtomicInteger;
15 import lenscorrection.Distortion_Correction.BasicParam;
16 import lenscorrection.Distortion_Correction.PointMatchCollectionAndAffine;
17 import mpicbg.models.Point;
18 import mpicbg.models.PointMatch;
19 import mpicbg.models.Tile;
20 import mpicbg.trakem2.align.AbstractAffineTile2D;
21 import mpicbg.trakem2.align.Align;
22 import mpicbg.trakem2.transform.CoordinateTransform;
23 import mpicbg.trakem2.transform.CoordinateTransformList;
25 import ij.IJ;
26 import ij.gui.GenericDialog;
27 import ini.trakem2.display.Display;
28 import ini.trakem2.display.Displayable;
29 import ini.trakem2.display.Layer;
30 import ini.trakem2.display.Patch;
31 import ini.trakem2.display.Selection;
32 import ini.trakem2.utils.Worker;
33 import ini.trakem2.utils.Bureaucrat;
34 import ini.trakem2.utils.IJError;
35 import ini.trakem2.utils.Utils;
37 /**
38 * Methods collection to be called from the GUI for alignment tasks.
41 final public class DistortionCorrectionTask
43 static public class CorrectDistortionFromSelectionParam extends BasicParam
45 public int firstLayerIndex;
46 public int lastLayerIndex;
47 public boolean clearTransform = false;
48 public boolean visualize = false;
49 public boolean tilesAreInPlace = false;
51 public void addFields( final GenericDialog gd, final Selection selection )
53 addFields( gd );
55 gd.addMessage( "Miscellaneous:" );
56 gd.addCheckbox( "tiles are rougly in place", tilesAreInPlace );
58 gd.addMessage( "Apply Distortion Correction :" );
60 Utils.addLayerRangeChoices( selection.getLayer(), gd );
61 gd.addCheckbox( "clear_present_transforms", clearTransform );
62 gd.addCheckbox( "visualize_distortion_model", visualize );
65 @Override
66 public boolean readFields( final GenericDialog gd )
68 super.readFields( gd );
69 tilesAreInPlace = gd.getNextBoolean();
70 firstLayerIndex = gd.getNextChoiceIndex();
71 lastLayerIndex = gd.getNextChoiceIndex();
72 clearTransform = gd.getNextBoolean();
73 visualize = gd.getNextBoolean();
74 return !gd.invalidNumber();
77 public boolean setup( final Selection selection )
79 final GenericDialog gd = new GenericDialog( "Distortion Correction" );
80 addFields( gd, selection );
83 gd.showDialog();
84 if ( gd.wasCanceled() ) return false;
86 while ( !readFields( gd ) );
88 return true;
91 @Override
92 public CorrectDistortionFromSelectionParam clone()
94 final CorrectDistortionFromSelectionParam p = new CorrectDistortionFromSelectionParam();
95 p.sift.set( sift );
96 p.dimension = dimension;
97 p.expectedModelIndex = expectedModelIndex;
98 p.lambda = lambda;
99 p.maxEpsilon = maxEpsilon;
100 p.minInlierRatio = minInlierRatio;
101 p.rod = rod;
102 p.tilesAreInPlace = tilesAreInPlace;
103 p.firstLayerIndex = firstLayerIndex;
104 p.lastLayerIndex = lastLayerIndex;
105 p.clearTransform = clearTransform;
106 p.visualize = visualize;
108 return p;
114 * Sets a {@link CoordinateTransform} to a list of patches.
116 final static protected class SetCoordinateTransformThread extends Thread
118 final protected List< Patch > patches;
119 final protected CoordinateTransform transform;
120 final protected AtomicInteger ai;
122 public SetCoordinateTransformThread(
123 final List< Patch > patches,
124 final CoordinateTransform transform,
125 final AtomicInteger ai )
127 this.patches = patches;
128 this.transform = transform;
129 this.ai = ai;
132 @Override
133 final public void run()
135 for ( int i = ai.getAndIncrement(); i < patches.size() && !isInterrupted(); i = ai.getAndIncrement() )
137 final Patch patch = patches.get( i );
138 // IJ.log( "Setting transform \"" + transform + "\" for patch \"" + patch.getTitle() + "\"." );
139 patch.setCoordinateTransform( transform );
140 patch.updateMipmaps();
142 IJ.showProgress( i, patches.size() );
147 final static protected void setCoordinateTransform(
148 final List< Patch > patches,
149 final CoordinateTransform transform,
150 final int numThreads )
152 final AtomicInteger ai = new AtomicInteger( 0 );
153 final List< SetCoordinateTransformThread > threads = new ArrayList< SetCoordinateTransformThread >();
155 for ( int i = 0; i < numThreads; ++i )
157 final SetCoordinateTransformThread thread = new SetCoordinateTransformThread( patches, transform, ai );
158 threads.add( thread );
159 thread.start();
163 for ( final Thread thread : threads )
164 thread.join();
166 catch ( InterruptedException e )
168 IJ.log( "Setting CoordinateTransform failed.\n" + e.getMessage() + "\n" + e.getStackTrace() );
174 * Appends a {@link CoordinateTransform} to a list of patches.
176 final static protected class AppendCoordinateTransformThread extends Thread
178 final protected List< Patch > patches;
179 final protected CoordinateTransform transform;
180 final protected AtomicInteger ai;
182 public AppendCoordinateTransformThread(
183 final List< Patch > patches,
184 final CoordinateTransform transform,
185 final AtomicInteger ai )
187 this.patches = patches;
188 this.transform = transform;
189 this.ai = ai;
192 @Override
193 final public void run()
195 for ( int i = ai.getAndIncrement(); i < patches.size() && !isInterrupted(); i = ai.getAndIncrement() )
197 final Patch patch = patches.get( i );
198 patch.appendCoordinateTransform( transform );
199 patch.updateMipmaps();
201 IJ.showProgress( i, patches.size() );
206 final static protected void appendCoordinateTransform(
207 final List< Patch > patches,
208 final CoordinateTransform transform,
209 final int numThreads )
211 final AtomicInteger ai = new AtomicInteger( 0 );
212 final List< AppendCoordinateTransformThread > threads = new ArrayList< AppendCoordinateTransformThread >();
214 for ( int i = 0; i < numThreads; ++i )
216 final AppendCoordinateTransformThread thread = new AppendCoordinateTransformThread( patches, transform, ai );
217 threads.add( thread );
218 thread.start();
222 for ( final Thread thread : threads )
223 thread.join();
225 catch ( InterruptedException e )
227 IJ.log( "Appending CoordinateTransform failed.\n" + e.getMessage() + "\n" + e.getStackTrace() );
231 final static public CorrectDistortionFromSelectionParam correctDistortionFromSelectionParam = new CorrectDistortionFromSelectionParam();
233 final static public Bureaucrat correctDistortionFromSelectionTask ( final Selection selection )
235 Worker worker = new Worker("Distortion Correction", false, true) {
236 public void run() {
237 startedWorking();
238 try {
239 correctDistortionFromSelection( selection );
240 Display.repaint(selection.getLayer());
241 } catch (Throwable e) {
242 IJError.print(e);
243 } finally {
244 finishedWorking();
247 public void cleanup() {
248 if (!selection.isEmpty())
249 selection.getLayer().getParent().undoOneStep();
252 return Bureaucrat.createAndStart( worker, selection.getProject() );
255 final static public Bureaucrat correctDistortionFromSelection( final Selection selection )
257 final List< Patch > patches = new ArrayList< Patch >();
258 for ( Displayable d : Display.getFront().getSelection().getSelected() )
259 if ( d instanceof Patch ) patches.add( ( Patch )d );
261 if ( patches.size() < 2 )
263 Utils.log("No images in the selection.");
264 return null;
267 final Worker worker = new Worker( "Lens correction" )
269 final public void run()
273 startedWorking();
275 if ( !correctDistortionFromSelectionParam.setup( selection ) ) return;
277 final CorrectDistortionFromSelectionParam p = correctDistortionFromSelectionParam.clone();
278 final Align.ParamOptimize ap = Align.paramOptimize.clone();
279 ap.sift.set( p.sift );
280 ap.desiredModelIndex = ap.expectedModelIndex = p.expectedModelIndex;
281 ap.maxEpsilon = p.maxEpsilon;
282 ap.minInlierRatio = p.minInlierRatio;
283 ap.rod = p.rod;
285 /** Get all patches that will be affected. */
286 final List< Patch > allPatches = new ArrayList< Patch >();
287 for ( final Layer l : selection.getLayer().getParent().getLayers().subList( p.firstLayerIndex, p.lastLayerIndex + 1 ) )
288 for ( Displayable d : l.getDisplayables( Patch.class ) )
289 allPatches.add( ( Patch )d );
291 /** Unset the coordinate transforms of all patches if desired. */
292 if ( p.clearTransform )
294 setTaskName( "Clearing present transforms" );
295 setCoordinateTransform( allPatches, null, Runtime.getRuntime().availableProcessors() );
296 Display.repaint();
299 setTaskName( "Establishing SIFT correspondences" );
301 List< AbstractAffineTile2D< ? > > tiles = new ArrayList< AbstractAffineTile2D< ? > >();
302 List< AbstractAffineTile2D< ? > > fixedTiles = new ArrayList< AbstractAffineTile2D< ? > > ();
303 List< Patch > fixedPatches = new ArrayList< Patch >();
304 final Displayable active = selection.getActive();
305 if ( active != null && active instanceof Patch )
306 fixedPatches.add( ( Patch )active );
307 Align.tilesFromPatches( ap, patches, fixedPatches, tiles, fixedTiles );
309 final List< AbstractAffineTile2D< ? >[] > tilePairs = new ArrayList< AbstractAffineTile2D< ? >[] >();
311 if ( p.tilesAreInPlace )
312 AbstractAffineTile2D.pairOverlappingTiles( tiles, tilePairs );
313 else
314 AbstractAffineTile2D.pairTiles( tiles, tilePairs );
316 final AbstractAffineTile2D< ? > fixedTile = fixedTiles.iterator().next();
318 Align.connectTilePairs( ap, tiles, tilePairs, Runtime.getRuntime().availableProcessors() );
321 /** Shift all local coordinates into the original image frame */
322 for ( final AbstractAffineTile2D< ? > tile : tiles )
324 final Rectangle box = tile.getPatch().getCoordinateTransformBoundingBox();
325 for ( final PointMatch m : tile.getMatches() )
327 final float[] l = m.getP1().getL();
328 final float[] w = m.getP1().getW();
329 l[ 0 ] += box.x;
330 l[ 1 ] += box.y;
331 w[ 0 ] = l[ 0 ];
332 w[ 1 ] = l[ 1 ];
336 if ( Thread.currentThread().isInterrupted() ) return;
338 List< Set< Tile< ? > > > graphs = AbstractAffineTile2D.identifyConnectedGraphs( tiles );
339 if ( graphs.size() > 1 )
340 IJ.log( "Could not interconnect all images with correspondences. " );
342 final List< AbstractAffineTile2D< ? > > interestingTiles;
344 /** Find largest graph. */
345 Set< Tile< ? > > largestGraph = null;
346 for ( Set< Tile< ? > > graph : graphs )
347 if ( largestGraph == null || largestGraph.size() < graph.size() )
348 largestGraph = graph;
350 interestingTiles = new ArrayList< AbstractAffineTile2D< ? > >();
351 for ( Tile< ? > t : largestGraph )
352 interestingTiles.add( ( AbstractAffineTile2D< ? > )t );
354 if ( Thread.currentThread().isInterrupted() ) return;
356 IJ.log( "Estimating lens model:" );
358 /* initialize with pure affine */
359 Align.optimizeTileConfiguration( ap, interestingTiles, fixedTiles );
361 /* measure the current error */
362 double e = 0;
363 int n = 0;
364 for ( final AbstractAffineTile2D< ? > t : interestingTiles )
365 for ( final PointMatch pm : t.getMatches() )
367 e += pm.getDistance();
368 ++n;
370 e /= n;
372 double dEpsilon_i = 0;
373 double epsilon_i = e;
374 double dEpsilon_0 = 0;
375 NonLinearTransform lensModel = null;
377 IJ.log( "0: epsilon = " + e );
379 /* Store original point locations */
380 final HashMap< Point, Point > originalPoints = new HashMap< Point, Point >();
381 for ( final AbstractAffineTile2D< ? > t : interestingTiles )
382 for ( final PointMatch pm : t.getMatches() )
383 originalPoints.put( pm.getP1(), pm.getP1().clone() );
385 for ( int i = 1; i < 2 || dEpsilon_i <= dEpsilon_0 / 100; ++i )
387 /* Some data shuffling for the lens correction interface */
388 final List< PointMatchCollectionAndAffine > matches = new ArrayList< PointMatchCollectionAndAffine >();
389 for ( AbstractAffineTile2D< ? >[] tilePair : tilePairs )
391 final AffineTransform a = tilePair[ 0 ].createAffine();
392 a.preConcatenate( tilePair[ 1 ].getModel().createInverseAffine() );
393 final Collection< PointMatch > commonMatches = new ArrayList< PointMatch >();
394 tilePair[ 0 ].commonPointMatches( tilePair[ 1 ], commonMatches );
395 final Collection< PointMatch > originalCommonMatches = new ArrayList< PointMatch >();
396 for ( final PointMatch pm : commonMatches )
397 originalCommonMatches.add( new PointMatch(
398 originalPoints.get( pm.getP1() ),
399 originalPoints.get( pm.getP2() ) ) );
400 matches.add( new PointMatchCollectionAndAffine( a, originalCommonMatches ) );
403 setTaskName( "Estimating lens distortion correction" );
405 lensModel = Distortion_Correction.createInverseDistortionModel(
406 matches,
407 p.dimension,
408 p.lambda,
409 ( int )fixedTile.getWidth(),
410 ( int )fixedTile.getHeight() );
412 /* update local points */
413 for ( final AbstractAffineTile2D< ? > t : interestingTiles )
414 for ( final PointMatch pm : t.getMatches() )
416 final Point currentPoint = pm.getP1();
417 final Point originalPoint = originalPoints.get( currentPoint );
418 final float[] l = currentPoint.getL();
419 final float[] lo = originalPoint.getL();
420 l[ 0 ] = lo[ 0 ];
421 l[ 1 ] = lo[ 1 ];
422 lensModel.applyInPlace( l );
425 /* re-optimize */
426 Align.optimizeTileConfiguration( ap, interestingTiles, fixedTiles );
428 /* measure the current error */
429 e = 0;
430 n = 0;
431 for ( final AbstractAffineTile2D< ? > t : interestingTiles )
432 for ( final PointMatch pm : t.getMatches() )
434 e += pm.getDistance();
435 ++n;
437 e /= n;
439 dEpsilon_i = e - epsilon_i;
440 epsilon_i = e;
441 if ( i == 1 ) dEpsilon_0 = dEpsilon_i;
443 IJ.log( i + ": epsilon = " + e );
444 IJ.log( i + ": epsilon = " + dEpsilon_i );
447 if ( lensModel != null )
449 if ( p.visualize )
451 setTaskName( "Visualizing lens distortion correction" );
452 lensModel.visualizeSmall( p.lambda );
455 setTaskName( "Applying lens distortion correction" );
457 appendCoordinateTransform( allPatches, lensModel, Runtime.getRuntime().availableProcessors() );
459 IJ.log( "Done." );
461 else
462 IJ.log( "No lens model found." );
464 Display.repaint();
466 catch ( Exception e ) { IJError.print( e ); }
467 finally { finishedWorking(); }
471 return Bureaucrat.createAndStart( worker, selection.getProject() );