/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SPegasos.java
 * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 <!-- globalinfo-start -->
 * Implements the stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007). This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data. Can either minimize the hinge loss (SVM) or log loss (logistic regression). For more information, see<br/>
 * <br/>
 * S. Shalev-Shwartz, Y. Singer, N. Srebro: Pegasos: Primal Estimated sub-GrAdient SOlver for SVM. In: 24th International Conference on Machine Learning, 807-814, 2007.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Shalev-Shwartz2007,
 *    author = {S. Shalev-Shwartz and Y. Singer and N. Srebro},
 *    booktitle = {24th International Conference on Machine Learning},
 *    pages = {807-814},
 *    title = {Pegasos: Primal Estimated sub-GrAdient SOlver for SVM},
 *    year = {2007}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).
 *  (default = 0)</pre>
 *
 * <pre> -L &lt;double&gt;
 *  The lambda regularization constant (default = 0.0001)</pre>
 *
 * <pre> -E &lt;integer&gt;
 *  The number of epochs to perform (batch learning only, default = 500)</pre>
 *
 * <pre> -N
 *  Don't normalize the data</pre>
 *
 * <pre> -M
 *  Don't replace missing values</pre>
 <!-- options-end -->
 *
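 * A typical command-line invocation might look like this (the dataset
 * path is hypothetical):
 * <pre>
 * java weka.classifiers.functions.SPegasos -t data.arff -F 0 -L 0.0001 -E 500
 * </pre>
 * <p/>
 *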
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision$
 */
public class SPegasos extends AbstractClassifier
  implements TechnicalInformationHandler, UpdateableClassifier,
             OptionHandler {

  /** For serialization */
  private static final long serialVersionUID = -3732968666673530290L;

  /** Replace missing values */
  protected ReplaceMissingValues m_replaceMissing;

  /** Convert nominal attributes to numerically coded binary ones */
  protected NominalToBinary m_nominalToBinary;

  /** Normalize the training data */
  protected Normalize m_normalize;

  /** The regularization parameter */
  protected double m_lambda = 0.0001;

  /** Stores the weights (+ bias in the last element) */
  protected double[] m_weights;

  /** Holds the current iteration number */
  protected double m_t;

  /**
   * The number of epochs to perform (batch learning). The total number
   * of iterations is m_epochs * number of instances.
   */
  protected int m_epochs = 500;

  /**
   * Turn off normalization of the input data. This option gets
   * forced for incremental training.
   */
  protected boolean m_dontNormalize = false;

  /**
   * Turn off global replacement of missing values. Missing values
   * will be ignored instead. This option gets forced for
   * incremental training.
   */
  protected boolean m_dontReplaceMissing = false;

  /** Holds the header of the training data */
  protected Instances m_data;

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // instances
    result.setMinimumNumberInstances(0);

    return result;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String lambdaTipText() {
    return "The regularization constant. (default = 0.0001)";
  }

  /**
   * Set the value of lambda to use.
   *
   * @param lambda the value of lambda to use
   */
  public void setLambda(double lambda) {
    m_lambda = lambda;
  }

  /**
   * Get the current value of lambda.
   *
   * @return the current value of lambda
   */
  public double getLambda() {
    return m_lambda;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String epochsTipText() {
    return "The number of epochs to perform (batch learning). " +
      "The total number of iterations is epochs * num" +
      " instances.";
  }

  /**
   * Set the number of epochs to use.
   *
   * @param e the number of epochs to use
   */
  public void setEpochs(int e) {
    m_epochs = e;
  }

  /**
   * Get the current number of epochs.
   *
   * @return the current number of epochs
   */
  public int getEpochs() {
    return m_epochs;
  }

  /**
   * Turn normalization off/on.
   *
   * @param m true if normalization is to be disabled.
   */
  public void setDontNormalize(boolean m) {
    m_dontNormalize = m;
  }

  /**
   * Get whether normalization has been turned off.
   *
   * @return true if normalization has been disabled.
   */
  public boolean getDontNormalize() {
    return m_dontNormalize;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String dontNormalizeTipText() {
    return "Turn normalization off";
  }

  /**
   * Turn global replacement of missing values off/on. If turned off,
   * then missing values are effectively ignored.
   *
   * @param m true if global replacement of missing values is to be
   * turned off.
   */
  public void setDontReplaceMissing(boolean m) {
    m_dontReplaceMissing = m;
  }

  /**
   * Get whether global replacement of missing values has been
   * disabled.
   *
   * @return true if global replacement of missing values has been turned
   * off
   */
  public boolean getDontReplaceMissing() {
    return m_dontReplaceMissing;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String dontReplaceMissingTipText() {
    return "Turn off global replacement of missing values";
  }

  /**
   * Set the loss function to use.
   *
   * @param function the loss function to use.
   */
  public void setLossFunction(SelectedTag function) {
    if (function.getTags() == TAGS_SELECTION) {
      m_loss = function.getSelectedTag().getID();
    }
  }

  /**
   * Get the current loss function.
   *
   * @return the current loss function.
   */
  public SelectedTag getLossFunction() {
    return new SelectedTag(m_loss, TAGS_SELECTION);
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String lossFunctionTipText() {
    return "The loss function to use. Hinge loss (SVM) " +
      "or log loss (logistic regression).";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration<Option> listOptions() {
    Vector<Option> newVector = new Vector<Option>();

    newVector.add(new Option("\tSet the loss function to minimize. 0 = " +
        "hinge loss (SVM), 1 = log loss (logistic regression).\n" +
        "\t(default = 0)", "F", 1, "-F"));
    newVector.add(new Option("\tThe lambda regularization constant " +
        "(default = 0.0001)", "L", 1, "-L <double>"));
    newVector.add(new Option("\tThe number of epochs to perform (" +
        "batch learning only, default = 500)", "E", 1, "-E <integer>"));
    newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
    newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -F
   *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).
   *  (default = 0)</pre>
   *
   * <pre> -L &lt;double&gt;
   *  The lambda regularization constant (default = 0.0001)</pre>
   *
   * <pre> -E &lt;integer&gt;
   *  The number of epochs to perform (batch learning only, default = 500)</pre>
   *
   * <pre> -N
   *  Don't normalize the data</pre>
   *
   * <pre> -M
   *  Don't replace missing values</pre>
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    reset();

    String lossString = Utils.getOption('F', options);
    if (lossString.length() != 0) {
      setLossFunction(new SelectedTag(Integer.parseInt(lossString),
          TAGS_SELECTION));
    } else {
      setLossFunction(new SelectedTag(HINGE, TAGS_SELECTION));
    }

    String lambdaString = Utils.getOption('L', options);
    if (lambdaString.length() > 0) {
      setLambda(Double.parseDouble(lambdaString));
    }

    String epochsString = Utils.getOption("E", options);
    if (epochsString.length() > 0) {
      setEpochs(Integer.parseInt(epochsString));
    }

    setDontNormalize(Utils.getFlag("N", options));
    setDontReplaceMissing(Utils.getFlag('M', options));
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    ArrayList<String> options = new ArrayList<String>();

    options.add("-F"); options.add("" + getLossFunction().getSelectedTag().getID());
    options.add("-L"); options.add("" + getLambda());
    options.add("-E"); options.add("" + getEpochs());
    if (getDontNormalize()) {
      options.add("-N");
    }
    if (getDontReplaceMissing()) {
      options.add("-M");
    }

    return options.toArray(new String[0]);
  }

  /**
   * Returns a string describing the classifier.
   *
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Implements the stochastic variant of the Pegasos" +
      " (Primal Estimated sub-GrAdient SOlver for SVM)" +
      " method of Shalev-Shwartz et al. (2007). This implementation" +
      " globally replaces all missing values and transforms nominal" +
      " attributes into binary ones. It also normalizes all attributes," +
      " so the coefficients in the output are based on the normalized" +
      " data. Can either minimize the hinge loss (SVM) or log loss (" +
      "logistic regression). For more information, see\n\n" +
      getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "S. Shalev-Shwartz and Y. Singer and N. Srebro");
    result.setValue(Field.YEAR, "2007");
    result.setValue(Field.TITLE, "Pegasos: Primal Estimated sub-GrAdient " +
        "SOlver for SVM");
    result.setValue(Field.BOOKTITLE, "24th International Conference on Machine " +
        "Learning");
    result.setValue(Field.PAGES, "807-814");

    return result;
  }

  /**
   * Reset the classifier.
   */
  public void reset() {
    m_t = 1;
    m_weights = null;
  }

  /**
   * Method for building the classifier.
   *
   * @param data the set of training instances.
   * @throws Exception if the classifier can't be built successfully.
   */
  public void buildClassifier(Instances data) throws Exception {
    reset();

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    data = new Instances(data);
    data.deleteWithMissingClass();
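
    // Preprocessing pipeline: replace missing values, convert nominal
    // attributes to binary, then normalize. The fitted filters are kept
    // so that distributionForInstance() can apply the same transforms
    // to incoming test instances.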
    if (data.numInstances() > 0 && !m_dontReplaceMissing) {
      m_replaceMissing = new ReplaceMissingValues();
      m_replaceMissing.setInputFormat(data);
      data = Filter.useFilter(data, m_replaceMissing);
    }

    // check for only numeric attributes
    boolean onlyNumeric = true;
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i != data.classIndex()) {
        if (!data.attribute(i).isNumeric()) {
          onlyNumeric = false;
          break;
        }
      }
    }

    if (!onlyNumeric) {
      m_nominalToBinary = new NominalToBinary();
      m_nominalToBinary.setInputFormat(data);
      data = Filter.useFilter(data, m_nominalToBinary);
    }

    if (!m_dontNormalize && data.numInstances() > 0) {
      m_normalize = new Normalize();
      m_normalize.setInputFormat(data);
      data = Filter.useFilter(data, m_normalize);
    }

    m_weights = new double[data.numAttributes() + 1];
    m_data = new Instances(data, 0);

    if (data.numInstances() > 0) {
      train(data);
    }
  }

  protected static final int HINGE = 0;
  protected static final int LOGLOSS = 1;

  /** The current loss function to minimize */
  protected int m_loss = HINGE;

  /** Loss functions to choose from */
  public static final Tag[] TAGS_SELECTION = {
    new Tag(HINGE, "Hinge loss (SVM)"),
    new Tag(LOGLOSS, "Log loss (logistic regression)")
  };
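
  /*
   * dloss returns the magnitude of the loss derivative at margin z.
   * For the hinge loss l(z) = max(0, 1 - z) this is 1 when z < 1 and 0
   * otherwise. For the log loss l(z) = ln(1 + e^(-z)) it is
   * 1 / (1 + e^(z)); the two branches below are algebraically equivalent
   * forms of that quantity, split on the sign of z so that Math.exp is
   * only ever called with a non-positive argument and cannot overflow.
   */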
  protected double dloss(double z) {
    if (m_loss == HINGE) {
      return (z < 1) ? 1 : 0;
    }

    // log loss
    if (z < 0) {
      return 1.0 / (Math.exp(z) + 1.0);
    } else {
      double t = Math.exp(-z);
      return t / (t + 1);
    }
  }

  private void train(Instances data) {
    for (int e = 0; e < m_epochs; e++) {
      for (int i = 0; i < data.numInstances(); i++) {
        Instance instance = data.instance(i);

        double learningRate = 1.0 / (m_lambda * m_t);
        //double scale = 1.0 - learningRate * m_lambda;
        double scale = 1.0 - 1.0 / m_t;
        double y = (instance.classValue() == 0) ? -1 : 1;
        double wx = dotProd(instance, m_weights, instance.classIndex());
        double z = y * (wx + m_weights[m_weights.length - 1]);
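
        // Pegasos step (Shalev-Shwartz et al., 2007): with step size
        // eta_t = 1 / (lambda * t), first decay the weights by
        // (1 - eta_t * lambda) = (1 - 1/t), add eta_t * dloss(z) * y * x
        // when the example incurs loss, and finally project the weight
        // vector back onto the ball of radius 1 / sqrt(lambda).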
        if (m_loss == LOGLOSS || (z < 1)) {
          double delta = learningRate * dloss(z);
          int n1 = instance.numValues();
          int n2 = data.numAttributes();
          for (int p1 = 0, p2 = 0; p2 < n2;) {
            int indS = 0;
            indS = (p1 < n1) ? instance.index(p1) : indS;
            int indP = p2;
            if (indP != data.classIndex()) {
              m_weights[indP] *= scale;
            }
            if (indS == indP) {
              if (indS != data.classIndex() &&
                  !instance.isMissingSparse(p1)) {
                //double m = learningRate * (instance.valueSparse(p1) * y);
                double m = delta * (instance.valueSparse(p1) * y);
                m_weights[indS] += m;
              }
              p1++;
            }
            p2++;
          }

          // update the bias
          m_weights[m_weights.length - 1] += delta * y;
        }

        double norm = 0;
        for (int k = 0; k < m_weights.length; k++) {
          if (k != data.classIndex()) {
            norm += (m_weights[k] * m_weights[k]);
          }
        }
        norm = Math.sqrt(norm);

        double scale2 = Math.min(1.0, (1.0 / (Math.sqrt(m_lambda) * norm)));
        if (scale2 < 1.0) {
          for (int j = 0; j < m_weights.length; j++) {
            m_weights[j] *= scale2;
          }
        }

        m_t++;
      }
    }
  }

  protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
    double result = 0;

    int n1 = inst1.numValues();
    int n2 = weights.length - 1;
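
    // Merge-join over the instance's sparse (index, value) pairs and the
    // dense weight vector: advance whichever pointer lags, accumulating
    // value * weight only where the indices coincide.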
    for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
      int ind1 = inst1.index(p1);
      int ind2 = p2;
      if (ind1 == ind2) {
        if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
          result += inst1.valueSparse(p1) * weights[p2];
        }
        p1++;
        p2++;
      } else if (ind1 > ind2) {
        p2++;
      } else {
        p1++;
      }
    }
    return (result);
  }

  /**
   * Updates the classifier with the given instance.
   *
   * @param instance the new training instance to include in the model
   * @exception Exception if the instance could not be incorporated in
   * the model.
   */
  public void updateClassifier(Instance instance) throws Exception {
    if (!instance.classIsMissing()) {
      double learningRate = 1.0 / (m_lambda * m_t);
      //double scale = 1.0 - learningRate * m_lambda;
      double scale = 1.0 - 1.0 / m_t;
      double y = (instance.classValue() == 0) ? -1 : 1;
      double wx = dotProd(instance, m_weights, instance.classIndex());
      double z = y * (wx + m_weights[m_weights.length - 1]);
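
      // Single-instance Pegasos step, mirroring the body of train().
      // Note: no filters are applied here, which is why normalization and
      // missing-value replacement are forced off for incremental training.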
      for (int j = 0; j < m_weights.length; j++) {
        m_weights[j] *= scale;
      }

      if (m_loss == LOGLOSS || (z < 1)) {
        double delta = learningRate * dloss(z);
        int n1 = instance.numValues();
        // The decay by (1 - 1/t) was already applied to every weight in
        // the loop above (unlike in train(), where it happens inside the
        // sparse sweep), so here we only add the gradient contribution of
        // the attributes actually present in the instance.
        for (int p1 = 0; p1 < n1; p1++) {
          int indS = instance.index(p1);
          if (indS != instance.classIndex() &&
              !instance.isMissingSparse(p1)) {
            m_weights[indS] += delta * (instance.valueSparse(p1) * y);
          }
        }

        // update the bias
        m_weights[m_weights.length - 1] += delta * y;
      }

      double norm = 0;
      for (int k = 0; k < m_weights.length; k++) {
        if (k != instance.classIndex()) {
          norm += (m_weights[k] * m_weights[k]);
        }
      }
      norm = Math.sqrt(norm);

      double scale2 = Math.min(1.0, (1.0 / (Math.sqrt(m_lambda) * norm)));
      if (scale2 < 1.0) {
        for (int j = 0; j < m_weights.length; j++) {
          m_weights[j] *= scale2;
        }
      }

      m_t++;
    }
  }

  /**
   * Computes the distribution for a given instance.
   *
   * @param inst the instance for which the distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance inst) throws Exception {
    double[] result = new double[2];

    if (m_replaceMissing != null) {
      m_replaceMissing.input(inst);
      inst = m_replaceMissing.output();
    }

    if (m_nominalToBinary != null) {
      m_nominalToBinary.input(inst);
      inst = m_nominalToBinary.output();
    }

    if (m_normalize != null) {
      m_normalize.input(inst);
      inst = m_normalize.output();
    }
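
    // With log loss the raw score is mapped through the logistic sigmoid,
    // giving calibrated class probabilities; with hinge loss the
    // classifier casts a hard 0/1 vote for the predicted class.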
    double wx = dotProd(inst, m_weights, inst.classIndex());
    double z = (wx + m_weights[m_weights.length - 1]);

    if (z <= 0) {
      if (m_loss == LOGLOSS) {
        result[0] = 1.0 / (1.0 + Math.exp(z));
        result[1] = 1.0 - result[0];
      } else {
        result[0] = 1;
      }
    } else {
      if (m_loss == LOGLOSS) {
        result[1] = 1.0 / (1.0 + Math.exp(-z));
        result[0] = 1.0 - result[1];
      } else {
        result[1] = 1;
      }
    }
    return result;
  }

  /**
   * Prints out the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    if (m_weights == null) {
      return "SPegasos: No model built yet.\n";
    }
    StringBuffer buff = new StringBuffer();
    buff.append("Loss function: ");
    if (m_loss == HINGE) {
      buff.append("Hinge loss (SVM)\n\n");
    } else {
      buff.append("Log loss (logistic regression)\n\n");
    }
    int printed = 0;

    for (int i = 0; i < m_weights.length - 1; i++) {
      if (i != m_data.classIndex()) {
        if (printed > 0) {
          buff.append(" + ");
        } else {
          buff.append(" ");
        }

        buff.append(Utils.doubleToString(m_weights[i], 12, 4) +
            " " + ((m_normalize != null) ? "(normalized) " : "")
            + m_data.attribute(i).name() + "\n");

        printed++;
      }
    }

    if (m_weights[m_weights.length - 1] > 0) {
      buff.append(" + " + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
    } else {
      buff.append(" - " + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
    }

    return buff.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision$");
  }

  /**
   * Main method for testing this class.
   *
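   * A minimal sketch of programmatic use, as an alternative to the
   * command line (the ARFF path below is hypothetical):
   * <pre>
   * Instances data = new Instances(
   *     new java.io.BufferedReader(new java.io.FileReader("train.arff")));
   * data.setClassIndex(data.numAttributes() - 1);
   * SPegasos s = new SPegasos();
   * s.buildClassifier(data);
   * System.out.println(s);
   * </pre>
   *
   * @param args the command-line options
   */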
  public static void main(String[] args) {
    runClassifier(new SPegasos(), args);
  }
}