orange_hacks/knn_weighted.py
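"""
Weighted k-nearest-neighbour regression for Orange 2.x.

KnnWeightedLearner / KnnWeightedClassifier predict a continuous class as a
weighted average of the k nearest neighbours' target values, with weights
either inversely proportional to the distance (d ** -alpha) or exponential
in it (alpha ** d).  test_housing() compares the learner against Orange's
own kNNLearner on the "housing" data; plot_im() renders the regression
surface of a small 2D toy problem as images.
"""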
import Orange
import numpy
import random
import math
import logging
class KnnWeightedLearner(Orange.classification.Learner):
    def __new__(cls, examples=None, **kwargs):
        learner = Orange.classification.Learner.__new__(cls, **kwargs)
        if examples:
            # initialize explicitly and return a trained classifier right away
            learner.__init__(**kwargs)
            return learner.__call__(examples)
        else:
            # plain learner; Python invokes __init__ after __new__ returns
            return learner
    def __init__(self,
                 k=0,
                 alpha=1,
                 distance_constructor=Orange.distance.Euclidean(),
                 exp_weight=False,
                 name='knn weighted'):
        # k=0 means "pick k automatically" (square root of the training set
        # size, see __call__); alpha is the exponent of the inverse-distance
        # weights, or the base of alpha ** distance when exp_weight is True
        self.k = k
        self.alpha = alpha
        self.distance_constructor = distance_constructor
        self.name = name
        self.exp_weight = exp_weight
    def __call__(self, data, weight=0):
        if not data.domain.class_var:
            raise ValueError('classless domain')
        # regression only: the class variable must be continuous
        assert isinstance(data.domain.class_var, Orange.feature.Continuous)

        fnc = Orange.classification.knn.FindNearestConstructor()
        fnc.distance_constructor = self.distance_constructor
        did = Orange.feature.Descriptor.new_meta_id()

        fn = fnc(data, 0, did)

        k = self.k
        if k == 0:
            # default k grows with the square root of the training set size
            k = int(math.sqrt(len(data)))

        return KnnWeightedClassifier(data.domain, k, fn, self.alpha, self.exp_weight)
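
# Minimal usage sketch: calling the learner on a table with a continuous
# class variable returns a trained classifier ("housing" is the dataset
# test_housing() below uses); passing the table straight to the constructor
# does the same in one step (see __new__).
#
#   data = Orange.data.Table("housing")
#   knn = KnnWeightedLearner(k=5, alpha=1)(data)
#   print knn(data[0])          # predicted continuous value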
## FIXME Orange.classification.Classifier (which should be the base class)
## is commented out because if it is not, pickling does not work...
class KnnWeightedClassifier: #(Orange.classification.Classifier):
    def __init__(self, domain, k, find_nearest, alpha, exp_weight):
        self.domain = domain
        # feature-only domain, i.e. without the class variable
        self.domain_f = Orange.data.Domain(domain.features)
        self.k = k
        self.find_nearest = find_nearest
        self.alpha = alpha
        self.exp_weight = exp_weight
    def __call__(self, instance, resultType=Orange.core.GetValue):
        # the instance must be described by the same feature descriptors
        # as the training data
        if list(instance.domain.features) != list(self.domain_f.features):
            raise ValueError("instance has wrong domain")

        def get_dist(nb):
            # distance to the query, stored by find_nearest as a meta attribute
            return nb[self.find_nearest.distance_ID]

        nbs = self.find_nearest(instance, self.k)

        # distances of the k nearest neighbours
        dsts = numpy.array([ get_dist(nb) for nb in nbs ])
        # target values of the k nearest neighbours
        clss = numpy.array([ nb.get_class() for nb in nbs ])

        if 0 in dsts:
            # avoid division by zero in the weighting below
            #logging.warn('0 in distances, add epsilon')
            dsts += 1e-5
        # compute the neighbour weights
        if not self.exp_weight:
            # weights inversely proportional to the distance
            w = dsts ** (-self.alpha)
        else:
            assert 0.0 < self.alpha < 1.0
            # exponential weighting: alpha ** distance
            w = self.alpha ** dsts

        # normalize the weights to sum to 1
        w = w / w.sum()
        # prediction: weighted combination of the neighbours' target values
        res = (w * clss).sum()
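        # worked example: with alpha=1 (inverse-distance weights) and neighbour
        # distances [1.0, 2.0, 4.0], the raw weights are [1.0, 0.5, 0.25];
        # normalized they become [4/7, 2/7, 1/7], so the nearest neighbour
        # contributes four times as much as the farthest one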
        value = self.domain.class_var(res)

        # kNN regression gives a point estimate; report it as a degenerate
        # distribution with all the mass on the predicted value
        dist = Orange.statistics.distribution.Continuous(self.domain.class_var)
        dist[value] = 1.

        if resultType == Orange.core.GetValue:
            return value
        if resultType == Orange.core.GetProbabilities:
            return dist
        return (value, dist)
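
    # Result types (sketch): anything other than GetValue / GetProbabilities
    # falls through to the (value, distribution) pair, so the usual
    # Orange.core.GetBoth call is assumed to work as well:
    #
    #   value = clf(inst)                                 # GetValue (default)
    #   probs = clf(inst, Orange.core.GetProbabilities)   # degenerate distribution
    #   value, probs = clf(inst, Orange.core.GetBoth)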
## tests and examples
def test_housing():
    data = Orange.data.Table("housing")

    from fann_neural import FannNeuralLearner

    learners = [
        KnnWeightedLearner( k=4, alpha=2 ),
        KnnWeightedLearner( k=4, alpha=1 ),
        KnnWeightedLearner( k=4, alpha=0 ),
        Orange.classification.knn.kNNLearner(k=4, name='knn 4'),
        Orange.classification.knn.kNNLearner(k=4, name='knn 4, rank_weight=False', rank_weight=False),
        ]

    cv = Orange.evaluation.testing.cross_validation(learners, data, folds=5)

    for l, score in zip(learners, Orange.evaluation.scoring.RMSE(cv)):
        print "%s: %.8f" % (l.name, score)
def plot_im():
    """
    this is somewhat inspired by
    http://quasiphysics.wordpress.com/2011/12/13/visualizing-k-nearest-neighbor-regression/
    """
    import Image, ImageDraw

    attrs = [ Orange.feature.Continuous(name) for name in ['X', 'Y', 'color'] ]
    insts = []
    random.seed(50)
    for num in xrange(10):
        # black or white target value (0 or 255)
        color = 255 * int(2 * random.random())

        def get_point():
            # keep the points away from the borders of the unit square
            return 0.25 + random.random() / 2

        x, y = get_point(), get_point()

        insts.append([x, y, color])

    # the last attribute ('color') becomes the class variable
    data = Orange.data.Table(Orange.data.Domain(attrs), insts)

    def get_inst(a, b):
        return Orange.data.Instance(Orange.data.Domain(data.domain.features), [a, b])
    for k in xrange(1, 11):
        for alpha in xrange(4):
            for dist in [Orange.distance.Euclidean() ]: #, Orange.distance.Manhattan() ]:

                l = KnnWeightedLearner( k=k, alpha=alpha, distance_constructor=dist)
                #l = Orange.classification.knn.kNNLearner( k=k )
                knn = l(data)

                size = 200

                # predict the colour for every pixel of a size x size grid
                a = []
                for X in xrange(size):
                    for Y in xrange(size):
                        val = int(knn(get_inst(float(X)/size, float(Y)/size)))
                        a.append(val)

                arr = numpy.array(a, dtype=numpy.uint8)
                arr = arr.reshape((size, size))
                im = Image.fromarray(arr).convert("RGB")
                draw = ImageDraw.Draw(im)
                for inst in data:
                    # image rows correspond to X, columns to Y
                    y, x = int(size * inst[0]), int(size * inst[1])
                    color = int(inst[2])

                    # mark each training point with a red circle filled with its value
                    r = size / 50
                    draw.ellipse((x-r, y-r, x+r, y+r), outline=(255, 0, 0), fill=(color, color, color))

                # note: the output directory knn_w/ must exist beforehand
                fn = "knn_w/k=%d_alpha=%d_dist=%s.ppm" % (k, alpha, dist.name)
                print fn
                im.save(fn)
                #im.show()
if __name__ == "__main__":
    #plot_im()
    test_housing()