separated CoordinateTransform parts of lenscorrection.NonLinearTransform
[trakem2.git] / TrakEM2_ / src / main / java / lenscorrection / NonLinearTransform.java
blob59f1f4fd2e88a74d90fafac933984cb1ba975988
1 /**
3 Copyright (C) 2008 Verena Kaynig.
5 This program is free software; you can redistribute it and/or
6 modify it under the terms of the GNU General Public License
7 as published by the Free Software Foundation (http://www.gnu.org/licenses/gpl.txt )
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
17 **/
19 /* **************************************************************** *
20 * Representation of a non linear transform by explicit polynomial
21 * kernel expansion.
23 * TODO:
24 * - make different kernels available
25 * - inverse transform for visualization
26 * - improve image interpolation
27 * - apply and applyInPlace should use precalculated transform?
28 * (What about out of image range pixels?)
30 * Author: Verena Kaynig
31 * Contact: verena.kaynig@inf.ethz.ch
33 * **************************************************************** */
35 package lenscorrection;
37 import ij.ImagePlus;
38 import ij.io.FileSaver;
39 import ij.process.ByteProcessor;
40 import ij.process.ColorProcessor;
41 import ij.process.FloatProcessor;
42 import ij.process.ImageProcessor;
44 import java.awt.Color;
45 import java.awt.geom.GeneralPath;
46 import java.io.BufferedReader;
47 import java.io.BufferedWriter;
48 import java.io.FileNotFoundException;
49 import java.io.FileOutputStream;
50 import java.io.FileReader;
51 import java.io.IOException;
52 import java.io.OutputStreamWriter;
54 import mpicbg.trakem2.transform.NonLinearCoordinateTransform;
55 import Jama.Matrix;
58 public class NonLinearTransform extends NonLinearCoordinateTransform {
60 private double[][][] transField = null;
62 public int getDimension(){ return dimension; }
63 /** Deletes all dimension dependent properties */
64 public void setDimension( final int dimension )
66 this.dimension = dimension;
67 length = (dimension + 1)*(dimension + 2)/2;
69 beta = new double[length][2];
70 normMean = new double[length];
71 normVar = new double[length];
73 for (int i=0; i < length; i++){
74 normMean[i] = 0;
75 normVar[i] = 1;
77 transField = null;
78 precalculated = false;
81 private boolean precalculated = false;
83 public int getMinNumMatches()
85 return length;
89 public void fit( final double x[][], final double y[][], final double lambda )
91 final double[][] expandedX = kernelExpandMatrixNormalize( x );
93 final Matrix phiX = new Matrix( expandedX, expandedX.length, length );
94 final Matrix phiXTransp = phiX.transpose();
96 final Matrix phiXProduct = phiXTransp.times( phiX );
98 final int l = phiXProduct.getRowDimension();
99 final double lambda2 = 2 * lambda;
101 for (int i = 0; i < l; ++i )
102 phiXProduct.set( i, i, phiXProduct.get( i, i ) + lambda2 );
104 final Matrix phiXPseudoInverse = phiXProduct.inverse();
105 final Matrix phiXProduct2 = phiXPseudoInverse.times( phiXTransp );
106 final Matrix betaMatrix = phiXProduct2.times( new Matrix( y, y.length, 2 ) );
108 setBeta( betaMatrix.getArray() );
111 public void estimateDistortion( final double hack1[][], final double hack2[][], final double transformParams[][], final double lambda, final int w, final int h )
113 beta = new double[ length ][ 2 ];
114 normMean = new double[ length ];
115 normVar = new double[ length ];
117 for ( int i = 0; i < length; i++ )
119 normMean[ i ] = 0;
120 normVar[ i ] = 1;
123 width = w;
124 height = h;
126 /* TODO Find out how to keep some target points fixed (check fit method of NLT which is supposed to be exclusively forward) */
127 final double expandedX[][] = kernelExpandMatrixNormalize( hack1 );
128 final double expandedY[][] = kernelExpandMatrix( hack2 );
130 final int s = expandedX[ 0 ].length;
131 Matrix S1 = new Matrix( 2 * s, 2 * s );
132 Matrix S2 = new Matrix( 2 * s, 1 );
134 for ( int i = 0; i < expandedX.length; ++i )
136 final Matrix xk_ij = new Matrix( expandedX[ i ], 1 );
137 final Matrix xk_ji = new Matrix( expandedY[ i ], 1 );
139 final Matrix yk1a = xk_ij.minus( xk_ji.times( transformParams[ i ][ 0 ] ) );
140 final Matrix yk1b = xk_ij.times( 0.0 ).minus( xk_ji.times( -transformParams[ i ][ 2 ] ) );
141 final Matrix yk2a = xk_ij.times( 0.0 ).minus( xk_ji.times( -transformParams[ i ][ 1 ] ) );
142 final Matrix yk2b = xk_ij.minus( xk_ji.times( transformParams[ i ][ 3 ] ) );
144 final Matrix y = new Matrix( 2, 2 * s );
145 y.setMatrix( 0, 0, 0, s - 1, yk1a );
146 y.setMatrix( 0, 0, s, 2 * s - 1, yk1b );
147 y.setMatrix( 1, 1, 0, s - 1, yk2a );
148 y.setMatrix( 1, 1, s, 2 * s - 1, yk2b );
150 final Matrix xk = new Matrix( 2, 2 * expandedX[ 0 ].length );
151 xk.setMatrix( 0, 0, 0, s - 1, xk_ij );
152 xk.setMatrix( 1, 1, s, 2 * s - 1, xk_ij );
154 final double[] vals = { hack1[ i ][ 0 ], hack1[ i ][ 1 ] };
155 final Matrix c = new Matrix( vals, 2 );
157 final Matrix X = xk.transpose().times( xk ).times( lambda );
158 final Matrix Y = y.transpose().times( y );
160 S1 = S1.plus( Y.plus( X ) );
162 final double trans1 = ( transformParams[ i ][ 2 ] * transformParams[ i ][ 5 ] - transformParams[ i ][ 0 ] * transformParams[ i ][ 4 ] );
163 final double trans2 = ( transformParams[ i ][ 1 ] * transformParams[ i ][ 4 ] - transformParams[ i ][ 3 ] * transformParams[ i ][ 5 ] );
164 final double[] trans = { trans1, trans2 };
166 final Matrix translation = new Matrix( trans, 2 );
167 final Matrix YT = y.transpose().times( translation );
168 final Matrix XC = xk.transpose().times( c ).times( lambda );
170 S2 = S2.plus( YT.plus( XC ) );
172 final Matrix regularize = Matrix.identity( S1.getRowDimension(), S1.getColumnDimension() );
173 final Matrix newBeta = new Matrix( S1.plus( regularize.times( 0.001 ) ).inverse().times( S2 ).getColumnPackedCopy(), s );
175 setBeta( newBeta.getArray() );
/**
 * Creates a transform from existing state. The arrays are stored by
 * reference, not copied.
 *
 * @param b coefficient array to use as {@code beta}
 * @param nm array to use as {@code normMean}
 * @param nv array to use as {@code normVar}
 * @param d kernel dimension; {@code length} is derived as {@code (d + 1) * (d + 2) / 2}
 * @param w image width
 * @param h image height
 */
public NonLinearTransform( final double[][] b, final double[] nm, final double[] nv, final int d, final int w, final int h )
{
    this.dimension = d;
    this.length = ( d + 1 ) * ( d + 2 ) / 2;
    this.width = w;
    this.height = h;

    this.beta = b;
    this.normMean = nm;
    this.normVar = nv;
}
/**
 * Creates a transform of the given kernel dimension with all-zero
 * coefficients and an identity normalization (mean 0, variance 1).
 *
 * @param d kernel dimension; {@code length} is derived as {@code (d + 1) * (d + 2) / 2}
 * @param w image width
 * @param h image height
 */
public NonLinearTransform( final int d, final int w, final int h )
{
    this.dimension = d;
    this.length = ( d + 1 ) * ( d + 2 ) / 2;
    this.width = w;
    this.height = h;

    this.beta = new double[ this.length ][ 2 ];
    /* new arrays are zero-filled, so normMean already holds its initial values */
    this.normMean = new double[ this.length ];
    this.normVar = new double[ this.length ];
    for ( int k = 0; k < this.length; ++k )
    {
        this.normVar[ k ] = 1;
    }
}
205 public NonLinearTransform(){};
/**
 * Creates a transform by reading its state from a text file through
 * {@link #load(String)}.
 *
 * @param filename path of the file to read
 */
public NonLinearTransform( final String filename )
{
    load( filename );
}
/**
 * Creates a transform from a packed coefficient matrix in the layout produced
 * by {@link #getCoefficients()}: four rows of equal length, holding the first
 * and second column of {@code beta}, {@code normMean} and {@code normVar}.
 * <p>
 * The number of kernel terms is the number of <em>columns</em> of the matrix.
 * (Taking the number of rows, as an earlier version did, always yields 4 for
 * this layout and silently truncates the transform.)
 *
 * @param coeffMatrix packed coefficients, {@code 4 x length}
 * @param w image width
 * @param h image height
 */
public NonLinearTransform( final double[][] coeffMatrix, final int w, final int h )
{
    length = coeffMatrix[ 0 ].length;
    beta = new double[ length ][ 2 ];
    normMean = new double[ length ];
    normVar = new double[ length ];
    width = w;
    height = h;

    /* inverse of length = (d + 1) * (d + 2) / 2, solved for d */
    dimension = ( int )( -1.5 + Math.sqrt( 0.25 + 2 * length ) );

    for ( int i = 0; i < length; i++ )
    {
        beta[ i ][ 0 ] = coeffMatrix[ 0 ][ i ];
        beta[ i ][ 1 ] = coeffMatrix[ 1 ][ i ];
        normMean[ i ] = coeffMatrix[ 2 ][ i ];
        normVar[ i ] = coeffMatrix[ 3 ][ i ];
    }
}
229 void precalculateTransfom(){
230 transField = new double[width][height][2];
231 //double minX = width, minY = height, maxX = 0, maxY = 0;
233 for (int x=0; x<width; x++){
234 for (int y=0; y<height; y++){
235 final double[] position = {x,y};
236 final double[] featureVector = kernelExpand(position);
237 final double[] newPosition = multiply(beta, featureVector);
239 if ((newPosition[0] < 0) || (newPosition[0] >= width) ||
240 (newPosition[1] < 0) || (newPosition[1] >= height))
242 transField[x][y][0] = -1;
243 transField[x][y][1] = -1;
244 continue;
247 transField[x][y][0] = newPosition[0];
248 transField[x][y][1] = newPosition[1];
250 //minX = Math.min(minX, x);
251 //minY = Math.min(minY, y);
252 //maxX = Math.max(maxX, x);
253 //maxY = Math.max(maxY, y);
258 precalculated = true;
261 public double[][] getCoefficients(){
262 final double[][] coeffMatrix = new double[4][length];
264 for(int i=0; i<length; i++){
265 coeffMatrix[0][i] = beta[i][0];
266 coeffMatrix[1][i] = beta[i][1];
267 coeffMatrix[2][i] = normMean[i];
268 coeffMatrix[3][i] = normVar[i];
271 return coeffMatrix;
274 public void setBeta(final double[][] b){
275 beta = b;
276 //FIXME: test if normMean and normVar are still valid for this beta
279 public void print(){
280 System.out.println("beta:");
281 for (int i=0; i < beta.length; i++){
282 for (int j=0; j < beta[i].length; j++){
283 System.out.print(beta[i][j]);
284 System.out.print(" ");
286 System.out.println();
289 System.out.println("normMean:");
290 for (int i=0; i < normMean.length; i++){
291 System.out.print(normMean[i]);
292 System.out.print(" ");
295 System.out.println("normVar:");
296 for (int i=0; i < normVar.length; i++){
297 System.out.print(normVar[i]);
298 System.out.print(" ");
301 System.out.println("Image size:");
302 System.out.println("width: " + width + " height: " + height);
304 System.out.println();
308 public void save( final String filename )
310 try{
311 final BufferedWriter out = new BufferedWriter(
312 new OutputStreamWriter(
313 new FileOutputStream( filename) ) );
314 try{
315 out.write("Kerneldimension");
316 out.newLine();
317 out.write(Integer.toString(dimension));
318 out.newLine();
319 out.newLine();
320 out.write("number of rows");
321 out.newLine();
322 out.write(Integer.toString(length));
323 out.newLine();
324 out.newLine();
325 out.write("Coefficients of the transform matrix:");
326 out.newLine();
327 for (int i=0; i < length; i++){
328 String s = Double.toString(beta[i][0]);
329 s += " ";
330 s += Double.toString(beta[i][1]);
331 out.write(s);
332 out.newLine();
334 out.newLine();
335 out.write("normMean:");
336 out.newLine();
337 for (int i=0; i < length; i++){
338 out.write(Double.toString(normMean[i]));
339 out.newLine();
341 out.newLine();
342 out.write("normVar: ");
343 out.newLine();
344 for (int i=0; i < length; i++){
345 out.write(Double.toString(normVar[i]));
346 out.newLine();
348 out.newLine();
349 out.write("image size: ");
350 out.newLine();
351 out.write(width + " " + height);
352 out.close();
354 catch(final IOException e){System.out.println("IOException");}
356 catch(final FileNotFoundException e){System.out.println("File not found!");}
359 public void load(final String filename){
360 try{
361 final BufferedReader in = new BufferedReader(new FileReader(filename));
362 try{
363 String line = in.readLine(); //comment;
364 dimension = Integer.parseInt(in.readLine());
365 line = in.readLine(); //comment;
366 line = in.readLine(); //comment;
367 length = Integer.parseInt(in.readLine());
368 line = in.readLine(); //comment;
369 line = in.readLine(); //comment;
371 beta = new double[length][2];
373 for (int i=0; i < length; i++){
374 line = in.readLine();
375 final int ind = line.indexOf(" ");
376 beta[i][0] = Double.parseDouble(line.substring(0, ind));
377 beta[i][1] = Double.parseDouble(line.substring(ind+4));
380 line = in.readLine(); //comment;
381 line = in.readLine(); //comment;
383 normMean = new double[length];
385 for (int i=0; i < length; i++){
386 normMean[i]=Double.parseDouble(in.readLine());
389 line = in.readLine(); //comment;
390 line = in.readLine(); //comment;
392 normVar = new double[length];
394 for (int i=0; i < length; i++){
395 normVar[i]=Double.parseDouble(in.readLine());
397 line = in.readLine(); //comment;
398 line = in.readLine(); //comment;
399 line = in.readLine();
400 final int ind = line.indexOf(" ");
401 width = Integer.parseInt(line.substring(0, ind));
402 height = Integer.parseInt(line.substring(ind+4));
403 in.close();
405 print();
407 catch(final IOException e){System.out.println("IOException");}
409 catch(final FileNotFoundException e){System.out.println("File not found!");}
412 public ImageProcessor[] transform(final ImageProcessor ip){
413 if (!precalculated)
414 this.precalculateTransfom();
416 final ImageProcessor newIp = ip.createProcessor(ip.getWidth(), ip.getHeight());
417 if (ip instanceof ColorProcessor) ip.max(0);
418 final ImageProcessor maskIp = new ByteProcessor(ip.getWidth(),ip.getHeight());
420 for (int x=0; x < width; x++){
421 for (int y=0; y < height; y++){
422 if (transField[x][y][0] == -1){
423 continue;
425 newIp.set(x, y, (int) ip.getInterpolatedPixel((int)transField[x][y][0],(int)transField[x][y][1]));
426 maskIp.set(x,y,255);
429 return new ImageProcessor[]{newIp, maskIp};
432 public double[][] kernelExpandMatrixNormalize(final double positions[][]){
433 normMean = new double[length];
434 normVar = new double[length];
436 for (int i=0; i < length; i++){
437 normMean[i] = 0;
438 normVar[i] = 1;
441 final double expanded[][] = new double[positions.length][length];
443 for (int i=0; i < positions.length; i++){
444 expanded[i] = kernelExpand(positions[i]);
447 for (int i=0; i < length; i++){
448 double mean = 0;
449 double var = 0;
450 for (int j=0; j < expanded.length; j++){
451 mean += expanded[j][i];
454 mean /= expanded.length;
456 for (int j=0; j < expanded.length; j++){
457 var += (expanded[j][i] - mean)*(expanded[j][i] - mean);
459 var /= (expanded.length -1);
460 var = Math.sqrt(var);
462 normMean[i] = mean;
463 normVar[i] = var;
466 return kernelExpandMatrix(positions);
470 //this function uses the parameters already stored
471 //in this object to normalize the positions given.
472 public double[][] kernelExpandMatrix(final double positions[][]){
475 final double expanded[][] = new double[positions.length][length];
477 for (int i=0; i < positions.length; i++){
478 expanded[i] = kernelExpand(positions[i]);
481 return expanded;
485 public void inverseTransform(final double range[][]){
486 Matrix expanded = new Matrix(kernelExpandMatrix(range));
487 final Matrix b = new Matrix(beta);
489 final Matrix transformed = expanded.times(b);
490 expanded = new Matrix(kernelExpandMatrixNormalize(transformed.getArray()));
492 final Matrix r = new Matrix(range);
493 final Matrix invBeta = expanded.transpose().times(expanded).inverse().times(expanded.transpose()).times(r);
494 setBeta(invBeta.getArray());
497 //FIXME this takes way too much memory
498 public void visualize(){
500 final int density = Math.max(width,height)/32;
501 final int border = Math.max(width,height)/8;
503 final double[][] orig = new double[width * height][2];
504 final double[][] trans = new double[height * width][2];
505 final double[][] gridOrigVert = new double[width*height][2];
506 final double[][] gridTransVert = new double[width*height][2];
507 final double[][] gridOrigHor = new double[width*height][2];
508 final double[][] gridTransHor = new double[width*height][2];
510 final FloatProcessor magnitude = new FloatProcessor(width, height);
511 final FloatProcessor angle = new FloatProcessor(width, height);
512 final ColorProcessor quiver = new ColorProcessor(width, height);
513 final ByteProcessor empty = new ByteProcessor(width+2*border, height+2*border);
514 quiver.setLineWidth(1);
515 quiver.setColor(Color.green);
517 final GeneralPath quiverField = new GeneralPath();
519 float minM = 1000, maxM = 0;
520 float minArc = 5, maxArc = -6;
521 int countVert = 0, countHor = 0, countHorWhole = 0;
523 for (int i=0; i < width; i++){
524 countHor = 0;
525 for (int j=0; j < height; j++){
526 final double[] position = {(double) i,(double) j};
527 final double[] posExpanded = kernelExpand(position);
528 final double[] newPosition = multiply(beta, posExpanded);
530 orig[i*j][0] = position[0];
531 orig[i*j][1] = position[1];
533 trans[i*j][0] = newPosition[0];
534 trans[i*j][1] = newPosition[1];
536 double m = (position[0] - newPosition[0]) * (position[0] - newPosition[0]);
537 m += (position[1] - newPosition[1]) * (position[1] - newPosition[1]);
538 m = Math.sqrt(m);
539 magnitude.setf(i,j, (float) m);
540 minM = Math.min(minM, (float) m);
541 maxM = Math.max(maxM, (float) m);
543 final double a = Math.atan2(position[0] - newPosition[0], position[1] - newPosition[1]);
544 minArc = Math.min(minArc, (float) a);
545 maxArc = Math.max(maxArc, (float) a);
546 angle.setf(i,j, (float) a);
548 if (i%density == 0 && j%density == 0)
549 drawQuiverField(quiverField, position[0], position[1], newPosition[0], newPosition[1]);
550 if (i%density == 0){
551 gridOrigVert[countVert][0] = position[0] + border;
552 gridOrigVert[countVert][1] = position[1] + border;
553 gridTransVert[countVert][0] = newPosition[0] + border;
554 gridTransVert[countVert][1] = newPosition[1] + border;
555 countVert++;
557 if (j%density == 0){
558 gridOrigHor[countHor*width+i][0] = position[0] + border;
559 gridOrigHor[countHor*width+i][1] = position[1] + border;
560 gridTransHor[countHor*width+i][0] = newPosition[0] + border;
561 gridTransHor[countHor*width+i][1] = newPosition[1] + border;
562 countHor++;
563 countHorWhole++;
568 magnitude.setMinAndMax(minM, maxM);
569 angle.setMinAndMax(minArc, maxArc);
570 //System.out.println(" " + minArc + " " + maxArc);
572 final ImagePlus magImg = new ImagePlus("Magnitude of Distortion Field", magnitude);
573 magImg.show();
575 // ImagePlus angleImg = new ImagePlus("Angle of Distortion Field Vectors", angle);
576 // angleImg.show();
578 final ImagePlus quiverImg = new ImagePlus("Quiver Plot of Distortion Field", magnitude);
579 quiverImg.show();
580 quiverImg.getCanvas().setDisplayList(quiverField, Color.green, null );
581 quiverImg.updateAndDraw();
583 // GeneralPath gridOrig = new GeneralPath();
584 // drawGrid(gridOrig, gridOrigVert, countVert, height);
585 // drawGrid(gridOrig, gridOrigHor, countHorWhole, width);
586 // ImagePlus gridImgOrig = new ImagePlus("Distortion Grid", empty);
587 // gridImgOrig.show();
588 // gridImgOrig.getCanvas().setDisplayList(gridOrig, Color.green, null );
589 // gridImgOrig.updateAndDraw();
591 final GeneralPath gridTrans = new GeneralPath();
592 drawGrid(gridTrans, gridTransVert, countVert, height);
593 drawGrid(gridTrans, gridTransHor, countHorWhole, width);
594 final ImagePlus gridImgTrans = new ImagePlus("Distortion Grid", empty);
595 gridImgTrans.show();
596 gridImgTrans.getCanvas().setDisplayList(gridTrans, Color.green, null );
597 gridImgTrans.updateAndDraw();
599 //new FileSaver(quiverImg.getCanvas().imp).saveAsTiff("QuiverCanvas.tif");
600 new FileSaver(quiverImg).saveAsTiff("QuiverImPs.tif");
602 System.out.println("FINISHED");
606 public void visualizeSmall(final double lambda){
607 final int density = Math.max(width,height)/32;
609 final double[][] orig = new double[2][width * height];
610 final double[][] trans = new double[2][height * width];
612 final FloatProcessor magnitude = new FloatProcessor(width, height);
614 final GeneralPath quiverField = new GeneralPath();
616 float minM = 1000, maxM = 0;
617 final float minArc = 5, maxArc = -6;
618 final int countVert = 0;
619 int countHor = 0;
620 final int countHorWhole = 0;
622 for (int i=0; i < width; i++){
623 countHor = 0;
624 for (int j=0; j < height; j++){
625 final double[] position = {(double) i,(double) j};
626 final double[] posExpanded = kernelExpand(position);
627 final double[] newPosition = multiply(beta, posExpanded);
629 orig[0][i*j] = position[0];
630 orig[1][i*j] = position[1];
632 trans[0][i*j] = newPosition[0];
633 trans[1][i*j] = newPosition[1];
635 double m = (position[0] - newPosition[0]) * (position[0] - newPosition[0]);
636 m += (position[1] - newPosition[1]) * (position[1] - newPosition[1]);
637 m = Math.sqrt(m);
638 magnitude.setf(i,j, (float) m);
639 minM = Math.min(minM, (float) m);
640 maxM = Math.max(maxM, (float) m);
642 if (i%density == 0 && j%density == 0)
643 drawQuiverField(quiverField, position[0], position[1], newPosition[0], newPosition[1]);
647 magnitude.setMinAndMax(minM, maxM);
648 final ImagePlus quiverImg = new ImagePlus("Quiver Plot for lambda = "+lambda, magnitude);
649 quiverImg.show();
650 quiverImg.getCanvas().setDisplayList(quiverField, Color.green, null );
651 quiverImg.updateAndDraw();
653 System.out.println("FINISHED");
657 public static void drawGrid(final GeneralPath g, final double[][] points, final int count, final int s){
658 for (int i=0; i < count - 1; i++){
659 if ((i+1)%s != 0){
660 g.moveTo((float)points[i][0], (float)points[i][1]);
661 g.lineTo((float)points[i+1][0], (float)points[i+1][1]);
666 public static void drawQuiverField(final GeneralPath qf, final double x1, final double y1, final double x2, final double y2)
668 qf.moveTo((float)x1, (float)y1);
669 qf.lineTo((float)x2, (float)y2);
672 public int getWidth(){
673 return width;
676 public int getHeight(){
677 return height;
681 * TODO Make this more efficient
683 @Override
684 final public NonLinearTransform copy()
686 final NonLinearTransform t = new NonLinearTransform();
687 t.init( toDataString() );
688 return t;
691 public void set( final NonLinearTransform nlt )
693 this.dimension = nlt.dimension;
694 this.height = nlt.height;
695 this.length = nlt.length;
696 this.precalculated = nlt.precalculated;
697 this.width = nlt.width;
699 /* arrays by deep cloning */
700 this.beta = new double[ nlt.beta.length ][];
701 for ( int i = 0; i < nlt.beta.length; ++i )
702 this.beta[ i ] = nlt.beta[ i ].clone();
704 this.normMean = nlt.normMean.clone();
705 this.normVar = nlt.normVar.clone();
706 this.transField = new double[ nlt.transField.length ][][];
708 for ( int a = 0; a < nlt.transField.length; ++a )
710 this.transField[ a ] = new double[ nlt.transField[ a ].length ][];
711 for ( int b = 0; b < nlt.transField[ a ].length; ++b )
712 this.transField[ a ][ b ] = nlt.transField[ a ][ b ].clone();