some late comments :)
[gostyle.git] / orange_hacks / stacking.py
blob 784f380336a7763af5332ad999f30229a7aa5cbb
1 import Orange
class StackedClassificationLearner(Orange.classification.Learner):
    """Stacking by inference of meta classifier from class probability estimates
    on cross-validation held-out data for level-0 classifiers developed on held-in data sets.

    :param learners: level-0 learners.
    :type learners: list

    :param meta_learner: meta learner (default: :class:`~Orange.classification.bayes.NaiveLearner`).
    :type meta_learner: :class:`~Orange.classification.Learner`

    :param folds: number of iterations (folds) of cross-validation to assemble class probability data for meta learner.
    :type folds: int

    :param name: learner name (default: stacking).
    :type name: string

    :rtype: :class:`~Orange.ensemble.stacking.StackedClassificationLearner` or
        :class:`~Orange.ensemble.stacking.StackedClassifier`
    """
    def __new__(cls, learners, data=None, weight=0, **kwds):
        # Without data, behave as a constructor and return the learner itself.
        if data is None:
            return Orange.classification.Learner.__new__(cls)
        # With data, build the learner and immediately train it, returning
        # the resulting StackedClassifier.
        self = cls(learners, **kwds)
        return self(data, weight)

    def __init__(self, learners, meta_learner=None, folds=10, name='stacking'):
        self.learners = learners
        # Create the default meta learner per instance. A default of
        # NaiveLearner() in the signature would be evaluated once, at class
        # definition, and shared by every StackedClassificationLearner.
        if meta_learner is None:
            meta_learner = Orange.classification.bayes.NaiveLearner()
        self.meta_learner = meta_learner
        self.name = name
        self.folds = folds

    def __call__(self, data, weight=0):
        """Train the stacked model on ``data`` and return a StackedClassifier.

        The meta learner is trained on the level-0 learners' predictions on
        cross-validation held-out data. The level-0 classifiers returned
        inside the StackedClassifier are then trained on all of ``data``.

        :param data: training data; its class variable must be discrete or
            continuous.
        :param weight: weight passed to the final level-0 training.
            NOTE(review): it is not passed to cross-validation or to the meta
            learner.
        :raises RuntimeError: if the class variable is neither
            :class:`Orange.feature.Discrete` nor :class:`Orange.feature.Continuous`.
        """
        class_var = data.domain.class_var
        if isinstance(class_var, Orange.feature.Discrete):
            discrete = True
            # One feature per class probability per learner, minus the last
            # probability of each learner (it is dropped below).
            n_features = len(self.learners) * (len(class_var.values) - 1)
        elif isinstance(class_var, Orange.feature.Continuous):
            discrete = False
            # One feature (the predicted value) per learner.
            n_features = len(self.learners)
        else:
            # Checked before cross-validation so invalid data fails fast.
            raise RuntimeError("unknown class_var type")

        # Held-out level-0 predictions form the meta learner's training set.
        res = Orange.evaluation.testing.cross_validation(self.learners, data, self.folds)

        features = [Orange.feature.Continuous("%d" % i) for i in range(n_features)]
        domain = Orange.data.Domain(features + [class_var])
        p_data = Orange.data.Table(domain)

        for r in res.results:
            if discrete:
                # Drop each learner's last class probability.
                row = [p for ps in r.probabilities for p in list(ps)[:-1]]
            else:
                row = list(r.classes)
            p_data.append(row + [r.actual_class])

        # Sanity check on the meta table layout; skipped when it is empty,
        # where indexing p_data[0] would raise IndexError.
        if len(p_data):
            assert len(p_data[0]) == len(domain)

        meta_classifier = self.meta_learner(p_data)
        classifiers = [learner(data, weight) for learner in self.learners]

        return StackedClassifier(classifiers, meta_classifier, name=self.name, meta_domain=p_data.domain)
class StackedClassifier:
    """
    Stacking classifier: the outputs of a set of level-0 classifiers are fed,
    as features, to a meta-classifier that makes the final prediction for a
    data instance.

    :param classifiers: a list of level-0 classifiers.
    :type classifiers: list

    :param meta_classifier: meta-classifier.
    :type meta_classifier: :class:`~Orange.classification.Classifier`

    :param meta_domain: domain of the meta-classifier's training data; its
        features are the level-0 outputs and its class variable decides
        whether probabilities (discrete) or values (continuous) are used.
    :type meta_domain: :class:`~Orange.data.Domain`

    Any additional keyword arguments (e.g. ``name``) become attributes.
    """
    def __init__(self, classifiers, meta_classifier, meta_domain, **kwds):
        self.classifiers = classifiers
        self.meta_classifier = meta_classifier
        self.meta_domain = meta_domain
        # Instances given to the meta-classifier carry no class value.
        self.domain = Orange.data.Domain(meta_domain.features, False)
        self.__dict__.update(kwds)

    def _level0_outputs(self, instance):
        """Return the flat list of level-0 outputs for ``instance``."""
        if isinstance(self.meta_domain.class_var, Orange.feature.Discrete):
            outputs = []
            for classifier in self.classifiers:
                dist = classifier(instance, Orange.core.GetProbabilities)
                # The last class probability is dropped for every classifier.
                outputs.extend(list(dist)[:-1])
            return outputs
        assert isinstance(self.meta_domain.class_var, Orange.feature.Continuous)
        return [float(classifier(instance, Orange.core.GetValue))
                for classifier in self.classifiers]

    def __call__(self, instance, resultType=Orange.core.GetValue):
        meta_instance = Orange.data.Instance(self.domain, self._level0_outputs(instance))
        return self.meta_classifier(meta_instance, resultType)
98 ## tests and examples
def test_stack_reggression():
    """Compare stacking with its level-0 learners on the ``housing`` data.

    Prints the 10-fold cross-validated RMSE of each learner.
    """
    base_learners = [
        Orange.regression.linear.LinearRegressionLearner(name='linear'),
        Orange.regression.pls.PLSRegressionLearner(name='PLS'),
        Orange.classification.knn.kNNLearner(k=20, name='knn 20'),
        Orange.classification.knn.kNNLearner(k=30, name='knn 30'),
        #Orange.ensemble.forest.RandomForestLearner(name='random forest'),
    ]

    stack = StackedClassificationLearner(
        base_learners,
        #meta_learner=Orange.ensemble.forest.RandomForestLearner(name='meta random forest'),
        meta_learner=Orange.classification.knn.kNNLearner(k=20, name='meta knn 20'),
        folds=10,
        name='stacking')

    learners = [stack] + base_learners

    data = Orange.data.Table("housing")
    res = Orange.evaluation.testing.cross_validation(learners, data, folds=10)
    scores = Orange.evaluation.scoring.RMSE(res)
    # A single-argument print() produces the same output on Python 2 and 3.
    print("\n".join(["%8s: %5.3f" % (learner.name, score)
                     for score, learner in zip(scores, learners)]))
def test_stack_classification():
    """Compare stacking with its level-0 learners on the ``promoters`` data.

    Prints the 3-fold cross-validated classification accuracy (CA) of each
    learner.
    """
    data = Orange.data.Table("promoters")

    bayes = Orange.classification.bayes.NaiveLearner(name="bayes")
    tree = Orange.classification.tree.SimpleTreeLearner(name="tree")
    lin = Orange.classification.svm.LinearLearner(name="lr")
    knn = Orange.classification.knn.kNNLearner(name="knn")

    base_learners = [bayes, tree, lin, knn]
    stack = StackedClassificationLearner(base_learners)

    learners = [stack] + base_learners
    res = Orange.evaluation.testing.cross_validation(learners, data, 3)
    scores = Orange.evaluation.scoring.CA(res)
    # A single-argument print() produces the same output on Python 2 and 3.
    print("\n".join(["%8s: %5.3f" % (learner.name, score)
                     for score, learner in zip(scores, learners)]))
if __name__ == "__main__":
    # Only the regression demo runs by default; add test_stack_classification
    # to this tuple to run the classification demo as well.
    for _demo in (test_stack_reggression,):
        _demo()