adding requests to conda install docs
[JPSSData.git] / synthetic.py
blob a4b68c152267e3c53dfe59e450977e74e5226dc3
1 import numpy as np
2 import matplotlib.pyplot as plt
3 from mpl_toolkits.mplot3d import axes3d
4 import matplotlib.colors as colors
5 from svm import SVM3
6 from scipy.io import savemat
7 from scipy import interpolate
def plot_case(xx,yy,tign_g,X_satellite=None,show=False,filename='syn_case.png'):
    """Draw a 3-D contour plot of the synthetic surface tign_g over (xx, yy).

    Parameters
    ----------
    xx, yy : arrays
        Grid coordinates (as produced by np.meshgrid in the script below).
    tign_g : array
        Surface values on the grid, drawn with 30 contour levels.
    X_satellite : array or None
        Optional points with columns (x, y, t), drawn as small red markers.
    show : bool
        If True, display the figure interactively; otherwise save it.
    filename : str
        Output path used when show is False (default 'syn_case.png',
        relative to the current working directory).
    """
    fig = plt.figure()
    # fig.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed
    # in 3.6; add_subplot works on both old and new versions.
    ax = fig.add_subplot(111, projection='3d')
    ax.contour(xx,yy,tign_g,30)
    if X_satellite is not None:
        ax.scatter(X_satellite[:,0],X_satellite[:,1],
                   X_satellite[:,2],s=5,color='r')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("T")
    if show:
        plt.show()
    else:
        fig.savefig(filename)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def plot_data(X,y,show=False,filename='syn_data.png'):
    """Scatter the labelled training points X in 3-D, coloured by label y.

    Parameters
    ----------
    X : array
        Points with columns (x, y, t).
    y : array
        Labels. They are mapped onto a two-colour green/red colormap spanning
        [y.min(), y.max()].
    show : bool
        If True, display the figure interactively; otherwise save it.
    filename : str
        Output path used when show is False (default 'syn_data.png',
        relative to the current working directory).
    """
    col = [(0, .5, 0), (.5, 0, 0)]
    cm_GR = colors.LinearSegmentedColormap.from_list('GrRd',col,N=2)
    fig = plt.figure()
    # fig.gca(projection='3d') was deprecated in Matplotlib 3.4 and removed
    # in 3.6; add_subplot works on both old and new versions.
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(X[:, 0], X[:, 1], X[:, 2],
               c=y, cmap=cm_GR, s=1, alpha=.5,
               vmin=y.min(), vmax=y.max())
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("T")
    if show:
        plt.show()
    else:
        fig.savefig(filename)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def cone_point(xx,yy,nx,ny):
    """Build a cone-shaped synthetic surface plus one satellite point.

    The surface is min(1e3, 10 + (2e3/cx) * sqrt(((xx-cx)**2 + (yy-cy)**2)/2)),
    centred at (cx, cy) = (nx/2, ny/2). Grid cells whose value equals the
    maximum of the surface (the 1e3 cap when it is reached) are dropped from
    the flattened outputs.

    Returns
    -------
    tign_g : full surface on the grid
    xx1d, yy1d, tt1d : flattened coordinates/values with the maximum removed
    X_satellite : (1, 3) array [[0.7*cx, 0.7*cy, tsat]], where tsat is half
        the range (max - min) of the surface
    """
    cx, cy = .5*nx, .5*ny
    radius = np.sqrt(((xx - cx)**2 + (yy - cy)**2)/2)
    tign_g = np.minimum(1e3, 10 + (2e3/cx)*radius)

    flat_t = tign_g.ravel()
    # Keep only cells strictly below the surface maximum (drops the plateau).
    below_top = flat_t < flat_t.max()

    tsat = (tign_g.max() - tign_g.min())/2
    X_satellite = np.array([[.7*cx, .7*cy, tsat]])
    return (tign_g,
            np.ravel(xx)[below_top],
            np.ravel(yy)[below_top],
            flat_t[below_top],
            X_satellite)
def cone_points(xx,yy,nx,ny,npoints=10):
    """Build a cone-shaped synthetic surface plus a line of satellite points.

    The surface is min(1e3, 10 + (2e3/cx) * sqrt(((xx-cx)**2 + (yy-cy)**2)/2)),
    centred at (cx, cy) = (nx/2, ny/2). Grid cells whose value equals the
    maximum of the surface (the 1e3 cap when it is reached) are dropped from
    the flattened outputs.

    Parameters
    ----------
    xx, yy : arrays
        Grid coordinates.
    nx, ny : int
        Grid dimensions; they define the cone centre.
    npoints : int
        Number of satellite points (default 10, the previously hard-coded
        value). They lie evenly spaced on the segment from
        (0.7*cx, 0.7*cy, tsat) towards (cx, cy, min(surface)), excluding that
        end point. npoints=1 yields the single point used by cone_point;
        npoints=0 yields an empty (0, 3) array.

    Returns
    -------
    tign_g, xx1d, yy1d, tt1d, X_satellite
    """
    cx = nx*.5
    cy = ny*.5
    tign_g = np.minimum(1e3,10+(2e3/cx)*np.sqrt(((xx-cx)**2+(yy-cy)**2)/2))
    # Half the range of the surface, used as the time of the first point.
    tsat = (tign_g.max()-tign_g.min())*.5
    tt1d = np.ravel(tign_g)
    # Drop cells sitting at the surface maximum (the capped plateau).
    mask = tt1d < tt1d.max()
    xx1d = np.ravel(xx)[mask]
    yy1d = np.ravel(yy)[mask]
    tt1d = tt1d[mask]
    # npoints+1 samples including the segment end, then drop the end point.
    X_satellite = np.c_[np.linspace(cx*.7,cx,npoints+1),
                        np.linspace(cy*.7,cy,npoints+1),
                        np.linspace(tsat,tign_g.min(),npoints+1)][:-1]
    return tign_g,xx1d,yy1d,tt1d,X_satellite
def slope(xx,yy,nx,ny):
    """Build a synthetic surface that varies only along x (no satellite data).

    Column i gets 10 + 10*i for i < round(nx/2). After that split it
    continues as 30 + 10*split + 30*(i - split), so the per-column increment
    changes from 10 to 30 (the original code labels (10, 30) "rate of spread").
    Every row of the (ny, nx) surface is identical.

    Returns
    -------
    tign_g : (ny, nx) surface
    xx1d, yy1d, tt1d : flattened grid coordinates and surface values
    X_satellite : always None
    """
    slow, fast = 10, 30
    split = round(nx*.5)
    col = np.arange(nx)
    left_part = 10 + col*slow
    right_part = fast + split*slow + (col - split)*fast
    profile = np.where(col < split, left_part, right_part)
    # Repeat the 1-D profile down ny rows -> shape (ny, nx).
    tign_g = np.tile(profile, (ny, 1))
    return tign_g, np.ravel(xx), np.ravel(yy), np.ravel(tign_g), None
def preprocess_svm(xx,yy,tt,epsilon,weights,X_satellite=None):
    """Assemble the labelled, weighted training set for the SVM.

    Every forecast point (x, y, t) yields two samples:
    (x, y, t - epsilon) labelled -1 ("no fire") and (x, y, t + epsilon)
    labelled +1 ("fire"). Satellite points, if given, are appended with
    label +1.

    Parameters
    ----------
    xx, yy, tt : arrays
        Forecast coordinates and values; each is flattened.
    epsilon : float
        Offset applied along the t axis to create the two forecast classes.
    weights : tuple (wforecastg, wforecastf, wsatellite)
        Per-sample weights for the -1 forecast samples, the +1 forecast
        samples, and the satellite samples, respectively.
    X_satellite : array or None
        Satellite points with columns (x, y, t). Either an (n, 3) array or a
        single point of shape (3,), which is promoted to (1, 3).

    Returns
    -------
    X : (m, 3) float array of samples (no-fire block, then fire block,
        then satellite points)
    y : (m,) float array of labels (-1 / +1)
    c : (m,) float array of per-sample weights
    """
    wforecastg,wforecastf,wsatellite = weights
    for_fire = np.c_[xx.ravel(),yy.ravel(),tt.ravel() + epsilon]
    for_nofire = np.c_[xx.ravel(),yy.ravel(),tt.ravel() - epsilon]
    n_nofire, n_fire = len(for_nofire), len(for_fire)
    X_forecast = np.concatenate((for_nofire,for_fire))
    y_forecast = np.concatenate((np.full(n_nofire, -1.0), np.full(n_fire, 1.0)))
    c_forecast = np.concatenate((np.full(n_nofire, wforecastg, dtype=float),
                                 np.full(n_fire, wforecastf, dtype=float)))
    if X_satellite is None:
        return X_forecast,y_forecast,c_forecast
    # Accept a single (3,) point as well as an (n, 3) array; without this,
    # np.concatenate fails on the dimension mismatch.
    X_satellite = np.atleast_2d(X_satellite)
    n_sat = len(X_satellite)
    X = np.concatenate((X_forecast,X_satellite))
    y = np.concatenate((y_forecast,np.full(n_sat, 1.0)))
    c = np.concatenate((c_forecast,np.full(n_sat, wsatellite, dtype=float)))
    return X,y,c
if __name__ == "__main__":
    ## SETTINGS
    # Experiments: 1) Cone with point, 2) Slope, 3) Cone with points
    exp = 2
    # hyperparameter settings
    wforecastg = 50
    wforecastf = 50
    wsatellite = 50
    kgam = 1
    # epsilon for artificial forecast in seconds
    epsilon = 1
    # dimensions
    nx, ny = 50, 50
    # plotting data before svm?
    plot = True

    ## CASE
    xx,yy = np.meshgrid(np.arange(0,nx,1),
                        np.arange(0,ny,1))
    # select experiment (all builders share the (xx,yy,nx,ny) signature)
    experiments = {1: cone_point, 2: slope, 3: cone_points}
    tign_g,xx1d,yy1d,tt1d,X_satellite = experiments[exp](xx,yy,nx,ny)
    if plot:
        plot_case(xx,yy,tign_g,X_satellite)

    ## PREPROCESS
    # Without satellite points the satellite weight is unused; zeroing it
    # also selects the shorter output filename below.
    if X_satellite is None:
        wsatellite = 0
    X,y,c = preprocess_svm(xx1d,yy1d,tt1d,epsilon,
                           (wforecastg,wforecastf,wsatellite),
                           X_satellite)
    if plot:
        plot_data(X,y)

    ## SVM
    # options for SVM
    options = {'downarti': False, 'plot_data': True,
               'plot_scaled': True, 'plot_supports': True,
               'plot_result': True, 'plot_decision': True,
               'artiu': False, 'hartiu': .2,
               'artil': False, 'hartil': .2,
               'notnan': True}
    # When all class weights that are in use are equal, pass a scalar C
    # instead of the per-sample weight vector.
    if (wforecastg == wforecastf and
            (wsatellite == 0 or wsatellite == wforecastg)):
        c = wforecastg
    # running SVM
    F = SVM3(X, y, C=c, kgam=kgam, **options)

    ## POSTPROCESS
    # Interpolate the SVM output back onto the original grid for validation.
    # F[0], F[1], F[2] are stored below as 'fxlon', 'fxlat' and 'Z'.
    points = np.c_[np.ravel(F[0]),np.ravel(F[1])]
    values = np.ravel(F[2])
    zz_svm = interpolate.griddata(points,values,(xx,yy))
    # output dictionary
    svm = {'xx': xx, 'yy': yy, 'zz': tign_g,
           'zz_svm': zz_svm, 'X': X, 'y': y, 'c': c,
           'fxlon': F[0], 'fxlat': F[1], 'Z': F[2],
           'epsilon': epsilon, 'options': options}
    # output file
    if wsatellite:
        filename = 'syn_fg%d_ff%d_s%d_k%d_e%d.mat' % (wforecastg,wforecastf,
                                                      wsatellite,kgam,epsilon)
    else:
        filename = 'syn_fg%d_ff%d_k%d_e%d.mat' % (wforecastg,wforecastf,
                                                  kgam,epsilon)
    savemat(filename, mdict=svm)
    # print() with one argument is valid in both Python 2 and Python 3;
    # the former `print '...'` statement is a SyntaxError on Python 3.
    print('plot_svm %s' % filename)