2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 * GNU General Public License for more details.
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
19 * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
23 package weka
.classifiers
.functions
;
25 import java
.util
.ArrayList
;
26 import java
.util
.Enumeration
;
27 import java
.util
.Vector
;
29 import weka
.classifiers
.AbstractClassifier
;
30 import weka
.classifiers
.UpdateableClassifier
;
31 import weka
.core
.Capabilities
;
32 import weka
.core
.Instance
;
33 import weka
.core
.Instances
;
34 import weka
.core
.Option
;
35 import weka
.core
.OptionHandler
;
36 import weka
.core
.RevisionUtils
;
37 import weka
.core
.SelectedTag
;
39 import weka
.core
.TechnicalInformation
;
40 import weka
.core
.TechnicalInformationHandler
;
41 import weka
.core
.Utils
;
42 import weka
.core
.Capabilities
.Capability
;
43 import weka
.core
.TechnicalInformation
.Field
;
44 import weka
.core
.TechnicalInformation
.Type
;
45 import weka
.filters
.Filter
;
46 import weka
.filters
.unsupervised
.attribute
.NominalToBinary
;
47 import weka
.filters
.unsupervised
.attribute
.ReplaceMissingValues
;
48 import weka
.filters
.unsupervised
.attribute
.Normalize
;
51 <!-- globalinfo-start -->
52 * Implements the stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007). This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data. Can either minimize the hinge loss (SVM) or log loss (logistic regression). For more information, see<br/>
 * S. Shalev-Shwartz, Y. Singer, N. Srebro: Pegasos: Primal Estimated sub-GrAdient SOlver for SVM. In: 24th International Conference on Machine Learning, 807-814, 2007.
56 <!-- globalinfo-end -->
58 <!-- technical-bibtex-start -->
61 * @inproceedings{Shalev-Shwartz2007,
62 * author = {S. Shalev-Shwartz and Y. Singer and N. Srebro},
63 * booktitle = {24th International Conference on MachineLearning},
65 * title = {Pegasos: Primal Estimated sub-GrAdient SOlver for SVM},
70 <!-- technical-bibtex-end -->
72 <!-- options-start -->
73 * Valid options are: <p/>
76 * Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).
79 * <pre> -L <double>
80 * The lambda regularization constant (default = 0.0001)</pre>
82 * <pre> -E <integer>
83 * The number of epochs to perform (batch learning only, default = 500)</pre>
86 * Don't normalize the data</pre>
89 * Don't replace missing values</pre>
93 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
/**
 * Implements the stochastic variant of Pegasos (Primal Estimated
 * sub-GrAdient SOlver for SVM) of Shalev-Shwartz et al. (2007),
 * minimizing either the hinge loss (SVM) or the log loss
 * (logistic regression).
 *
 * NOTE(review): this fragment appears truncated — the implements clause
 * ends with a dangling comma; verify the full interface list against the
 * complete source.
 */
public class SPegasos
  extends AbstractClassifier
  implements TechnicalInformationHandler, UpdateableClassifier,

  /** For serialization */
  private static final long serialVersionUID = -3732968666673530290L;

  /** Replace missing values (built only when replacement is enabled) */
  protected ReplaceMissingValues m_replaceMissing;

  /** Convert nominal attributes to numerically coded binary ones */
  protected NominalToBinary m_nominalToBinary;

  /** Normalize the training data (null when normalization is disabled) */
  protected Normalize m_normalize;

  /** The regularization parameter lambda (default 0.0001) */
  protected double m_lambda = 0.0001;

  /** Stores the weights (+ bias in the last element) */
  protected double[] m_weights;

  /** Holds the current iteration number; drives the 1/(lambda*t) learning rate */
  protected double m_t;

  /**
   * The number of epochs to perform (batch learning). Total iterations is
   * m_epochs * num instances.
   */
  protected int m_epochs = 500;

  /**
   * Turn off normalization of the input data. This option gets
   * forced for incremental training.
   */
  protected boolean m_dontNormalize = false;

  /**
   * Turn off global replacement of missing values. Missing values
   * will be ignored instead. This option gets forced for
   * incremental training.
   */
  protected boolean m_dontReplaceMissing = false;

  /** Holds the header (zero-instance copy) of the training data */
  protected Instances m_data;
  /**
   * Returns default capabilities of the classifier.
   *
   * NOTE(review): this fragment appears truncated — no "return result;"
   * is visible (and no result.disableAll() before the enable calls);
   * verify against the complete source.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attribute types the learner accepts
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class attribute: binary problems only
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // incremental scheme: no minimum number of training instances
    result.setMinimumNumberInstances(0);
169 * Returns the tip text for this property
171 * @return tip text for this property suitable for
172 * displaying in the explorer/experimenter gui
174 public String
lambdaTipText() {
175 return "The regularization constant. (default = 0.0001)";
  /**
   * Set the value of lambda to use.
   *
   * NOTE(review): the method body is missing from this fragment —
   * presumably it assigns the argument to m_lambda; verify against the
   * complete source.
   *
   * @param lambda the value of lambda to use
   */
  public void setLambda(double lambda) {
  /**
   * Get the current value of lambda.
   *
   * NOTE(review): the method body is missing from this fragment —
   * presumably it returns m_lambda; verify against the complete source.
   *
   * @return the current value of lambda
   */
  public double getLambda() {
  /**
   * Returns the tip text for this property.
   *
   * NOTE(review): the concatenated string below is cut off mid-expression
   * in this fragment.
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String epochsTipText() {
    return "The number of epochs to perform (batch learning). " +
      "The total number of iterations is epochs * num" +
  /**
   * Set the number of epochs to use.
   *
   * NOTE(review): the method body is missing from this fragment —
   * presumably it assigns the argument to m_epochs; verify against the
   * complete source.
   *
   * @param e the number of epochs to use
   */
  public void setEpochs(int e) {
  /**
   * Get current number of epochs.
   *
   * NOTE(review): the method body is missing from this fragment —
   * presumably it returns m_epochs; verify against the complete source.
   *
   * @return the current number of epochs
   */
  public int getEpochs() {
  /**
   * Turn normalization off/on.
   *
   * NOTE(review): the method body is missing from this fragment —
   * presumably it assigns the argument to m_dontNormalize; verify
   * against the complete source.
   *
   * @param m true if normalization is to be disabled.
   */
  public void setDontNormalize(boolean m) {
236 * Get whether normalization has been turned off.
238 * @return true if normalization has been disabled.
240 public boolean getDontNormalize() {
241 return m_dontNormalize
;
245 * Returns the tip text for this property
247 * @return tip text for this property suitable for
248 * displaying in the explorer/experimenter gui
250 public String
dontNormalizeTipText() {
251 return "Turn normalization off";
255 * Turn global replacement of missing values off/on. If turned off,
256 * then missing values are effectively ignored.
258 * @param m true if global replacement of missing values is to be
261 public void setDontReplaceMissing(boolean m
) {
262 m_dontReplaceMissing
= m
;
266 * Get whether global replacement of missing values has been
269 * @return true if global replacement of missing values has been turned
272 public boolean getDontReplaceMissing() {
273 return m_dontReplaceMissing
;
277 * Returns the tip text for this property
279 * @return tip text for this property suitable for
280 * displaying in the explorer/experimenter gui
282 public String
dontReplaceMissingTipText() {
283 return "Turn off global replacement of missing values";
287 * Set the loss function to use.
289 * @param function the loss function to use.
291 public void setLossFunction(SelectedTag function
) {
292 if (function
.getTags() == TAGS_SELECTION
) {
293 m_loss
= function
.getSelectedTag().getID();
298 * Get the current loss function.
300 * @return the current loss function.
302 public SelectedTag
getLossFunction() {
303 return new SelectedTag(m_loss
, TAGS_SELECTION
);
307 * Returns the tip text for this property
309 * @return tip text for this property suitable for
310 * displaying in the explorer/experimenter gui
312 public String
lossFunctionTipText() {
313 return "The loss function to use. Hinge loss (SVM) " +
314 "or log loss (logistic regression).";
  /**
   * Returns an enumeration describing the available options.
   *
   * NOTE(review): the third Option (-E) is cut off mid-expression in this
   * fragment — its final constructor argument (the synopsis string) is
   * missing; verify against the complete source.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration<Option> listOptions() {
    Vector<Option> newVector = new Vector<Option>();

    // -F: loss function selection
    newVector.add(new Option("\tSet the loss function to minimize. 0 = " +
      "hinge loss (SVM), 1 = log loss (logistic regression).\n" +
      "\t(default = 0)", "F", 1, "-F"));
    // -L: lambda regularization constant
    newVector.add(new Option("\tThe lambda regularization constant " +
      "(default = 0.0001)",
      "L", 1, "-L <double>"));
    // -E: number of epochs (batch learning only)
    newVector.add(new Option("\tThe number of epochs to perform (" +
      "batch learning only, default = 500)", "E", 1,
    // -N / -M: preprocessing switches
    newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
    newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

    return newVector.elements();
  /**
   * Parses a given list of options. <p/>
   *
   * Valid options are: <p/>
   * <pre> -F &lt;int&gt;
   *  Set the loss function to minimize. 0 = hinge loss (SVM),
   *  1 = log loss (logistic regression). (default = 0)</pre>
   * <pre> -L &lt;double&gt;
   *  The lambda regularization constant (default = 0.0001)</pre>
   * <pre> -E &lt;integer&gt;
   *  The number of epochs to perform (batch learning only, default = 500)</pre>
   * <pre> -N  Don't normalize the data</pre>
   * <pre> -M  Don't replace missing values</pre>
   *
   * NOTE(review): several statements below are cut off mid-expression in
   * this fragment (the -F SelectedTag construction and the braces/else
   * between the two setLossFunction calls); verify against the complete
   * source.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    // -F: loss function (falls back to hinge loss when absent)
    String lossString = Utils.getOption('F', options);
    if (lossString.length() != 0) {
      setLossFunction(new SelectedTag(Integer.parseInt(lossString),
    setLossFunction(new SelectedTag(HINGE, TAGS_SELECTION));

    // -L: lambda regularization constant
    String lambdaString = Utils.getOption('L', options);
    if (lambdaString.length() > 0) {
      setLambda(Double.parseDouble(lambdaString));

    // -E: number of epochs (batch learning only)
    String epochsString = Utils.getOption("E", options);
    if (epochsString.length() > 0) {
      setEpochs(Integer.parseInt(epochsString));

    // boolean flags
    setDontNormalize(Utils.getFlag("N", options));
    setDontReplaceMissing(Utils.getFlag('M', options));
  /**
   * Gets the current settings of the classifier.
   *
   * NOTE(review): the bodies of the two if-blocks are missing from this
   * fragment — presumably they add the "-N" / "-M" flags; verify against
   * the complete source. Also, idiomatic Java would pass new String[0]
   * to toArray rather than new String[1]; with a too-small array toArray
   * still allocates a correctly sized one, but String[1] is misleading.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    ArrayList<String> options = new ArrayList<String>();

    // flag/value pairs for the current configuration
    options.add("-F"); options.add("" + getLossFunction().getSelectedTag().getID());
    options.add("-L"); options.add("" + getLambda());
    options.add("-E"); options.add("" + getEpochs());
    if (getDontNormalize()) {
    if (getDontReplaceMissing()) {

    return options.toArray(new String[1]);
415 * Returns a string describing classifier
416 * @return a description suitable for
417 * displaying in the explorer/experimenter gui
419 public String
globalInfo() {
420 return "Implements the stochastic variant of the Pegasos" +
421 " (Primal Estimated sub-GrAdient SOlver for SVM)" +
422 " method of Shalev-Shwartz et al. (2007). This implementation" +
423 " globally replaces all missing values and transforms nominal" +
424 " attributes into binary ones. It also normalizes all attributes," +
425 " so the coefficients in the output are based on the normalized" +
426 " data. Can either minimize the hinge loss (SVM) or log loss (" +
427 "logistic regression). For more information, see\n\n" +
428 getTechnicalInformation().toString();
  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * NOTE(review): this fragment is truncated — the TITLE and BOOKTITLE
   * string concatenations are cut off mid-expression and no
   * "return result;" is visible; verify against the complete source.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    // Citation data for the Pegasos paper (ICML 2007).
    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "S. Shalev-Shwartz and Y. Singer and N. Srebro");
    result.setValue(Field.YEAR, "2007");
    result.setValue(Field.TITLE, "Pegasos: Primal Estimated sub-GrAdient " +
    result.setValue(Field.BOOKTITLE, "24th International Conference on Machine" +
    result.setValue(Field.PAGES, "807-814");
  /**
   * Reset the classifier.
   *
   * NOTE(review): the method body is missing from this fragment —
   * presumably it re-initializes the iteration counter and weights;
   * verify against the complete source.
   */
  public void reset() {
  /**
   * Method for building the classifier. Preprocesses the data
   * (missing-value replacement, nominal-to-binary conversion,
   * normalization — each step conditional) and allocates the weight
   * vector (one slot per attribute plus a trailing bias slot).
   *
   * NOTE(review): this fragment is truncated in several places — the
   * loop detecting non-numeric attributes has no visible body, the
   * NominalToBinary step is not guarded by the onlyNumeric flag here,
   * and the trailing if-block (presumably invoking the actual training)
   * is empty; verify against the complete source.
   *
   * @param data the set of training instances.
   * @throws Exception if the classifier can't be built successfully.
   */
  public void buildClassifier(Instances data) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // work on a copy; instances without a class value cannot be used
    data = new Instances(data);
    data.deleteWithMissingClass();

    // optional global replacement of missing values
    if (data.numInstances() > 0 && !m_dontReplaceMissing) {
      m_replaceMissing = new ReplaceMissingValues();
      m_replaceMissing.setInputFormat(data);
      data = Filter.useFilter(data, m_replaceMissing);

    // check for only numeric attributes
    boolean onlyNumeric = true;
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i != data.classIndex()) {
        if (!data.attribute(i).isNumeric()) {

    // convert nominal attributes to binary indicator attributes
    m_nominalToBinary = new NominalToBinary();
    m_nominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_nominalToBinary);

    // optional normalization of the inputs
    if (!m_dontNormalize && data.numInstances() > 0) {
      m_normalize = new Normalize();
      m_normalize.setInputFormat(data);
      data = Filter.useFilter(data, m_normalize);

    // one weight per attribute plus the bias in the last slot
    m_weights = new double[data.numAttributes() + 1];
    // keep a zero-instance header for later filtering/output
    m_data = new Instances(data, 0);

    if (data.numInstances() > 0) {
  /** Loss-function id: hinge loss (SVM) */
  protected static final int HINGE = 0;

  /** Loss-function id: log loss (logistic regression) */
  protected static final int LOGLOSS = 1;

  /** The current loss function to minimize */
  protected int m_loss = HINGE;

  /**
   * Loss functions to choose from.
   * NOTE(review): the closing of this array initializer is missing from
   * this fragment.
   */
  public static final Tag[] TAGS_SELECTION = {
    new Tag(HINGE, "Hinge loss (SVM)"),
    new Tag(LOGLOSS, "Log loss (logistic regression)")
  /**
   * Magnitude of the derivative of the configured loss at margin z,
   * used to scale the gradient step.
   *
   * NOTE(review): the log-loss branch is truncated in this fragment —
   * the two return paths appear without the guard selecting between
   * them (presumably a branch on the sign of z for numerical
   * stability); verify against the complete source.
   *
   * @param z the margin value y * (w.x + b)
   */
  protected double dloss(double z) {
    if (m_loss == HINGE) {
      // hinge loss: subgradient magnitude is 1 inside the margin, 0 outside
      return (z < 1) ? 1 : 0;
    // log loss (structure truncated below)
    return 1.0 / (Math.exp(z) + 1.0);
    double t = Math.exp(-z);
  /**
   * Batch training loop: runs m_epochs passes of the stochastic Pegasos
   * update over the data, then rescales the weight vector (a projection
   * step involving 1/sqrt(lambda)).
   *
   * NOTE(review): this fragment is heavily truncated — the declarations
   * of indS/indP/norm, the sparse-index bookkeeping (p1/p2 advancement),
   * any m_t increment, and many closing braces are missing; the comments
   * below describe only what is visible. Verify against the complete
   * source.
   *
   * @param data the (already filtered) training instances
   */
  private void train(Instances data) {
    for (int e = 0; e < m_epochs; e++) {
      for (int i = 0; i < data.numInstances(); i++) {
        Instance instance = data.instance(i);

        // learning rate 1/(lambda*t) decays with the iteration counter
        double learningRate = 1.0 / (m_lambda * m_t);
        //double scale = 1.0 - learningRate * m_lambda;
        // per-step multiplicative shrinkage of the weights
        double scale = 1.0 - 1.0 / m_t;
        // map the binary class {0,1} to {-1,+1}
        double y = (instance.classValue() == 0) ? -1 : 1;
        double wx = dotProd(instance, m_weights, instance.classIndex());
        // margin; the bias lives in the last weight slot
        double z = y * (wx + m_weights[m_weights.length - 1]);

        // log loss always takes a gradient step; hinge only inside the margin
        if (m_loss == LOGLOSS || (z < 1)) {
          double delta = learningRate * dloss(z);
          int n1 = instance.numValues();   // stored (sparse) values
          int n2 = data.numAttributes();   // all attributes
          for (int p1 = 0, p2 = 0; p2 < n2;) {
            indS = (p1 < n1) ? instance.index(p1) : indS;
            if (indP != data.classIndex()) {
              m_weights[indP] *= scale;
            if (indS != data.classIndex() &&
                !instance.isMissingSparse(p1)) {
              //double m = learningRate * (instance.valueSparse(p1) * y);
              double m = delta * (instance.valueSparse(p1) * y);
              m_weights[indS] += m;

        // gradient step for the bias term
        m_weights[m_weights.length - 1] += delta * y;

    // squared L2 norm of the weights (class index excluded)
    for (int k = 0; k < m_weights.length; k++) {
      if (k != data.classIndex()) {
        norm += (m_weights[k] * m_weights[k]);
    norm = Math.sqrt(norm);

    // rescale so the weight norm does not exceed 1/sqrt(lambda)
    double scale2 = Math.min(1.0, (1.0 / (Math.sqrt(m_lambda) * norm)));
    for (int j = 0; j < m_weights.length; j++) {
      m_weights[j] *= scale2;
  /**
   * Inner product of a (possibly sparse) instance with the weight
   * vector, skipping the class attribute; the bias slot is excluded via
   * n2 = weights.length - 1.
   *
   * NOTE(review): heavily truncated — the declarations of result and
   * ind2, the pointer-advancement logic of the merge walk, and the final
   * return are missing from this fragment; verify against the complete
   * source.
   *
   * @param inst1 the instance whose attribute values are multiplied in
   * @param weights the weight vector (bias in the last element)
   * @param classIndex index of the class attribute, which is skipped
   */
  protected static double dotProd(Instance inst1, double[] weights,
      int classIndex) {
    int n1 = inst1.numValues();        // stored (sparse) values
    int n2 = weights.length - 1;       // attribute weights, excluding bias

    // merge-style walk over the sparse indices and the dense weights
    for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
      int ind1 = inst1.index(p1);
      if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
        result += inst1.valueSparse(p1) * weights[p2];
    } else if (ind1 > ind2) {
  /**
   * Updates the classifier with the given instance (incremental /
   * online Pegasos step): instances with a missing class are skipped.
   *
   * NOTE(review): this fragment is heavily truncated — the declarations
   * of indS/indP/norm, the p1/p2 advancement, any m_t increment, and
   * many closing braces are missing. Also note the weights are
   * multiplied by scale both in the dense loop below and again via
   * m_weights[indP] *= scale inside the sparse loop — possibly an
   * artifact of the truncation; verify against the complete source.
   *
   * @param instance the new training instance to include in the model
   * @exception Exception if the instance could not be incorporated in
   *            the model
   */
  public void updateClassifier(Instance instance) throws Exception {
    if (!instance.classIsMissing()) {

      // learning rate 1/(lambda*t) decays with the iteration counter
      double learningRate = 1.0 / (m_lambda * m_t);
      //double scale = 1.0 - learningRate * m_lambda;
      // per-step multiplicative shrinkage of the weights
      double scale = 1.0 - 1.0 / m_t;
      // map the binary class {0,1} to {-1,+1}
      double y = (instance.classValue() == 0) ? -1 : 1;
      double wx = dotProd(instance, m_weights, instance.classIndex());
      // margin; the bias lives in the last weight slot
      double z = y * (wx + m_weights[m_weights.length - 1]);

      // shrink all weights
      for (int j = 0; j < m_weights.length; j++) {
        m_weights[j] *= scale;

      // log loss always takes a gradient step; hinge only inside the margin
      if (m_loss == LOGLOSS || (z < 1)) {
        double delta = learningRate * dloss(z);
        int n1 = instance.numValues();     // stored (sparse) values
        int n2 = instance.numAttributes(); // all attributes
        for (int p1 = 0, p2 = 0; p2 < n2;) {
          indS = (p1 < n1) ? instance.index(p1) : indS;
          if (indP != instance.classIndex()) {
            m_weights[indP] *= scale;
          if (indS != instance.classIndex() &&
              !instance.isMissingSparse(p1)) {
            double m = delta * (instance.valueSparse(p1) * y);
            m_weights[indS] += m;

      // gradient step for the bias term
      m_weights[m_weights.length - 1] += delta * y;

      // squared L2 norm of the weights (class index excluded)
      for (int k = 0; k < m_weights.length; k++) {
        if (k != instance.classIndex()) {
          norm += (m_weights[k] * m_weights[k]);
      norm = Math.sqrt(norm);

      // rescale so the weight norm does not exceed 1/sqrt(lambda)
      double scale2 = Math.min(1.0, (1.0 / (Math.sqrt(m_lambda) * norm)));
      for (int j = 0; j < m_weights.length; j++) {
        m_weights[j] *= scale2;
  /**
   * Computes the distribution for a given instance: pushes the instance
   * through whichever preprocessing filters were built during training,
   * then converts the raw score w.x + b into the two class values.
   *
   * NOTE(review): truncated — the hinge-loss branches and the final
   * "return result;" are missing from this fragment, and the if/else
   * structure around the two symmetric LOGLOSS blocks is not visible;
   * verify against the complete source.
   *
   * @param inst the instance for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance inst) throws Exception {
    double[] result = new double[2];

    // apply the same preprocessing as at training time (if any)
    if (m_replaceMissing != null) {
      m_replaceMissing.input(inst);
      inst = m_replaceMissing.output();

    if (m_nominalToBinary != null) {
      m_nominalToBinary.input(inst);
      inst = m_nominalToBinary.output();

    if (m_normalize != null) {
      m_normalize.input(inst);
      inst = m_normalize.output();

    // raw score: w.x + bias (bias stored in the last weight slot)
    double wx = dotProd(inst, m_weights, inst.classIndex());// * m_wScale;
    double z = (wx + m_weights[m_weights.length - 1]);
    //System.out.print("" + z + ": ");
    // System.out.println(1.0 / (1.0 + Math.exp(-z)));

    // log loss: sigmoid converts the score to a probability
    if (m_loss == LOGLOSS) {
      result[0] = 1.0 / (1.0 + Math.exp(z));
      result[1] = 1.0 - result[0];

    if (m_loss == LOGLOSS) {
      result[1] = 1.0 / (1.0 + Math.exp(-z));
      result[0] = 1.0 - result[1];
  /**
   * Prints out the classifier: the loss function followed by the
   * learned (possibly normalized) per-attribute weights and the bias.
   *
   * NOTE(review): truncated — the else branches around the loss label
   * and the bias sign, plus several closing braces, are missing from
   * this fragment; verify against the complete source.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    // nothing learned yet
    if (m_weights == null) {
      return "SPegasos: No model built yet.\n";

    StringBuffer buff = new StringBuffer();
    buff.append("Loss function: ");
    if (m_loss == HINGE) {
      buff.append("Hinge loss (SVM)\n\n");
    buff.append("Log loss (logistic regression)\n\n");

    // one line per attribute weight (bias excluded, class skipped)
    for (int i = 0 ; i < m_weights.length - 1; i++) {
      if (i != m_data.classIndex()) {
        buff.append(Utils.doubleToString(m_weights[i], 12, 4) +
          " " + ((m_normalize != null) ? "(normalized) " : "")
          + m_data.attribute(i).name() + "\n");

    // bias term, printed with an explicit sign
    if (m_weights[m_weights.length - 1] > 0) {
      buff.append(" + " + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
    buff.append(" - " + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));

    return buff.toString();
781 * Returns the revision string.
783 * @return the revision
785 public String
getRevision() {
786 return RevisionUtils
.extract("$Revision$");
790 * Main method for testing this class.
792 public static void main(String
[] args
) {
793 runClassifier(new SPegasos(), args
);