ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/curvefitting.py
Revision: 1
Committed: Wed Mar 17 05:34:43 2010 UTC (9 years, 6 months ago) by clausted
File size: 7738 byte(s)
Log Message:
Initial import of Osprai project
Line File contents
1
2 """
3 Perform the curve fitting for the data object.
4
5 Yuhang Wan, Feb25, 2010
6 __version__ = "1.0"
7
8
9 """
10
11
12 import numpy as np
13 import pylab as plt
14 import copy
15 from scipy.optimize import leastsq
16 ## import packageClamp_100225 as Pack
17 ## import SPRdataclass_100225 as SPR
18 ## import modelclass1_1 as Model
19
20
def initialize(data, parainfo):
    '''Split the model parameters into floating and fixed sets and report them.

    data     -- sequence of curves laid out as [t0, y0, t1, y1, ...]; only
                its length is used here (number of curves = len(data) // 2).
    parainfo -- list of parameter dicts. The keys read are 'name', 'value',
                'fixed' (1 = held fixed, 0 = floated in the fit) and
                'number' (1 for a single shared value, otherwise one value
                per curve, given as a list in 'value').

    Side effect: rebinds the module globals pfloat_name, pfloat0 and
    pfix_name (fitting() later reads pfloat_name).

    Returns (pfloat0_1D, pfix, pfloat_name_1D, pfix_name): the initial float
    values and their names flattened to 1-D by resizePfloat (flag 0), as
    scipy.optimize.leastsq requires, plus the fixed values and names.
    '''
    global pfloat_name, pfloat0, pfix_name
    # Separate the parameters to fit from those held fixed.
    pfloat_name, pfloat0, pfix_name, pfix = [], [], [], []
    for info in parainfo:
        if info['fixed'] == 1:
            pfix_name.append(info['name'])
            pfix.append(info['value'])
        elif info['fixed'] == 0:
            pfloat_name.append(info['name'])
            pfloat0.append(info['value'])
        else:
            # Parameter is neither fixed nor floating: warn and skip it.
            print('unreadable status.\nplease check the parameter infomation!')

    # Single-argument print keeps the output identical on Python 2 and 3.
    print('-' * 50)
    print('The fixed parameters are: %s' % (pfix_name,))
    print('The fixed value:')
    for name, value in zip(pfix_name, pfix):
        print('%s : %s' % (name, value))

    print('\n' + '-' * 50)
    print('The float parameters are: %s' % (pfloat_name,))
    print('The initial value:')
    for name, value in zip(pfloat_name, pfloat0):
        print('%s : %s' % (name, value))

    print('\n' + '-' * 50)

    # Format examination: check that per-curve parameters match the data.
    # Problems are reported (not raised); only the first one is shown.
    n_curve = len(data) // 2
    for info in parainfo:
        ntmp = info['number']
        if ntmp == 1:
            continue  # one shared value; nothing to check
        if ntmp != n_curve:
            print("the number of input curves doesn't match that of the "
                  "parameter\nPlease check the data and the parameters!")
            break
        value = info['value']
        # A per-curve value must be a list with one entry per curve.
        # resizePfloat only expands lists. Checking the type first also
        # avoids a TypeError from len() on a scalar value.
        if not isinstance(value, list) or ntmp != len(value):
            print('Please check the infomation of the parameters!')
            break

    pfloat0_1D, pfloat_name_1D = resizePfloat(pfloat0, pfloat_name, 0)
    return pfloat0_1D, pfix, pfloat_name_1D, pfix_name
69
def resizePfloat(p, pname, flag = 0):
    '''Convert the floating-parameter list between its two formats.

    mixed: values [p_a, p_b, [p_c1, p_c2, p_c3], p_d], names [a, b, c, d]
    1D:    values [p_a, p_b, p_c1, p_c2, p_c3, p_d], names [a, b, c, c, c, d]

    flag 0 -- mixed -> 1D: each list-valued entry is expanded in place and
              its name repeated once per element.
    flag 1 -- 1D -> mixed: a name that occurs once keeps its scalar value;
              a name that occurs n > 1 times collects the n values starting
              at its first occurrence into a list (names are assumed to be
              contiguous, as flag 0 produces them).
    Any other flag returns ([], []).

    Returns (new_p, new_pname). A length mismatch between p and pname is
    reported with a printed warning, not raised.
    '''
    # First examine the input.
    if len(p) != len(pname):
        print("The parameter list doesn't match the name list, please check again!")

    new_p, new_pname = [], []

    if flag == 0:
        # Pair values and names by position. Looking the name up with
        # p.index(value) would return the FIRST equal value and so mislabel
        # parameters that happen to share a value.
        for value, name in zip(p, pname):
            if isinstance(value, list):
                new_p.extend(value)
                new_pname.extend([name] * len(value))
            else:
                new_p.append(value)
                new_pname.append(name)

    elif flag == 1:
        # Count every name once up front instead of rescanning per element.
        counts = {}
        for name in pname:
            counts[name] = counts.get(name, 0) + 1
        seen = set()
        for i, name in enumerate(pname):
            n = counts[name]
            if n == 1:
                new_p.append(p[i])
                new_pname.append(name)
            elif name not in seen:
                # First occurrence of a per-curve parameter: gather its run.
                seen.add(name)
                new_pname.append(name)
                new_p.append(list(p[i:i + n]))

    return new_p, new_pname
109
def createdict(p, pname):
    '''Pack the name and value of all the parameters into a dictionary list.

    Returns [{'name': pname[k], 'value': p[k]}, ...] in input order.
    When the two sequences differ in length, prints a warning followed by
    both sequences and returns None.
    '''
    if len(p) != len(pname):
        print('check the input.\n')
        print(pname)
        print(p)
        return None
    return [{'name': name, 'value': value} for name, value in zip(pname, p)]
125
126 ## def keepsinglevalue(paratmp, i):
127 ## for n,p in enumerate(paratmp):
128 ## if type(p['value'])==list:
129 ## paralist[n]['value'] = p['value'][i]
130 ## else:
131 ## paralist[n]['value'] = p['value']
132 ## return paralist
133
134
def residuals(pfloat_1D, data, pfix):
    '''Error function for scipy.optimize.leastsq.

    pfloat_1D -- current floating-parameter values in 1-D form.
    data      -- curves laid out as [t0, y0, t1, y1, ...].
    pfix      -- values of the fixed parameters.

    Reads the module globals pfloat_name_1D, pfix_name and sprfunction
    (set by fitting()).

    Returns one flat array of raw residuals (model - measured) for all
    curves, concatenated in curve order. leastsq squares and sums this
    itself. Squaring here as well would make it minimize fourth powers of
    the errors. Concatenating also lets the curves differ in length.
    '''
    # leastsq passes the float parameters in 1-D form; regroup them.
    pfloat, pfloat_name = resizePfloat(pfloat_1D, pfloat_name_1D, 1)

    paratmp = createdict(pfloat + pfix, pfloat_name + pfix_name)
    paralist = copy.deepcopy(paratmp)

    errors = []
    for i in range(len(data) // 2):
        # Per-curve (list) parameters contribute their i-th entry;
        # shared parameters keep their single value.
        for n, p in enumerate(paratmp):
            if isinstance(p['value'], list):
                paralist[n]['value'] = p['value'][i]
            else:
                paralist[n]['value'] = p['value']

        t = data[i * 2]
        y = data[i * 2 + 1]
        model = np.asarray(sprfunction(t, paralist), dtype=float)
        errors.append(np.ravel(model - np.asarray(y, dtype=float)))

    if not errors:
        return np.zeros(0)
    return np.concatenate(errors)
156
157
158
def lmafit(data, pfix):
    '''Fit the floating parameters with Levenberg-Marquardt (scipy leastsq).

    data -- curves laid out as [t0, y0, t1, y1, ...].
    pfix -- values of the fixed parameters, passed through to residuals.

    Starts from the module global pfloat0_1D (1-D initial values, as
    leastsq requires) and regroups the result with pfloat_name_1D. Both
    are set by fitting(). Prints leastsq's status flag, warns if the fit
    did not succeed, and plots the result via plotfit.

    Returns the fitted floating parameters in mixed form (per-curve
    parameters as sub-lists).
    '''
    p_1D, ier = leastsq(residuals, pfloat0_1D, args=(data, pfix), maxfev=10000)
    p = resizePfloat(p_1D, pfloat_name_1D, 1)[0]
    # leastsq's second return value is an integer status flag, not a bool.
    # Values 1-4 mean a solution was found; anything else means failure.
    print('Success: %s' % (ier,))
    if ier not in (1, 2, 3, 4):
        print('Warning: leastsq did not find a solution (ier = %s).' % (ier,))
    plotfit(data, p_1D, pfix)

    return p
172
173
def plotfit(data, pfit_1D, pfix):
    '''Plot every measured curve together with its fitted curve.

    data    -- curves laid out as [t0, y0, t1, y1, ...].
    pfit_1D -- fitted floating-parameter values in 1-D form.
    pfix    -- values of the fixed parameters.

    The plot title carries the sum of squares of the residuals() output,
    which is also printed, followed by the fitted floating parameters.
    Reads the module globals pfloat_name_1D, pfix_name and sprfunction,
    and rebinds the module global pfit_name.
    '''
    global pfit_name
    pfit, pfit_name = resizePfloat(pfit_1D, pfloat_name_1D, 1)

    sse = sum(np.square(residuals(pfit_1D, data, pfix)))
    txt = 'SumSqE: %1.8f' % sse
    print(txt)
    plt.title(txt)
    plt.xlabel('Time (s)')
    plt.ylabel('Response (uRIU)')
    plt.grid(True)

    template = createdict(pfit + pfix, pfit_name + pfix_name)
    curve_params = copy.deepcopy(template)
    n_curves = len(data) // 2
    for k in range(n_curves):
        # Per-curve (list) parameters contribute their k-th entry;
        # shared parameters keep their single value.
        for slot, entry in zip(curve_params, template):
            val = entry['value']
            slot['value'] = val[k] if type(val) == list else val
        t, y = data[2 * k], data[2 * k + 1]
        plt.plot(t, y, ',')                          # measured data
        plt.plot(t, sprfunction(t, curve_params))    # fitted curve
    plt.show()

    print('\n' + '-' * 50)
    print('The fitted parameters are: %s' % (pfit_name,))
    print('The fitted value:')
    for name, value in zip(pfit_name, pfit):
        print('%s : %s' % (name, value))
211
212
def updateparainfo(parainfo, pfit, pfit_name):
    '''Return a deep copy of parainfo with the fitted values filled in.

    Every entry whose 'name' appears in pfit_name gets the matching value
    from pfit (same position). Each replacement is printed as a
    "name / initial value / fitted value" row. The input list is not
    modified.
    '''
    updated = copy.deepcopy(parainfo)
    print('-' * 50)
    print('name\t initial value \t fitted value')
    for entry in updated:
        for idx, fitted_name in enumerate(pfit_name):
            if entry['name'] != fitted_name:
                continue
            fitted = pfit[idx]
            print('%s \t%s \t%s' % (fitted_name, entry['value'], fitted))
            entry['value'] = fitted
    return updated
224
225
226
def fitting(dataobj, mobj):
    '''The main function of curvefitting.

    Input the data object and the model object. The curves are read from
    dataobj.data ([t0, y0, t1, y1, ...]); the parameter table and the model
    function are read from a deep copy of mobj (its .parainfo and
    .function), so mobj itself is never modified.

    Side effect: rebinds the module globals pfloat0_1D, pfix,
    pfloat_name_1D, pfix_name and sprfunction, which residuals, lmafit and
    plotfit read.

    Returns (pfit, modelobj): the fitted floating parameters (per-curve
    ones as sub-lists) and the copied model. The copy is updated with the
    fitted values only if the user answers y at the prompt.
    '''
    global pfloat0_1D, pfix, pfloat_name_1D, pfix_name, sprfunction
    modelobj = copy.deepcopy(mobj)
    data = dataobj.data
    parainfo = modelobj.parainfo
    pfloat0_1D, pfix, pfloat_name_1D, pfix_name = initialize(data, parainfo)
    sprfunction = modelobj.function

    pfit = lmafit(data, pfix)
    # pfloat_name is the module global written by initialize(): one name
    # per floating parameter, aligned with the mixed-form pfit.
    parainfo = updateparainfo(parainfo, pfit, pfloat_name)
    print('Do you want to update the model with fitted value?')
    # raw_input exists only on Python 2; on Python 3 the equivalent is input.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    if read_line('y/n : ').strip().upper() == 'Y':
        modelobj.updatemodel(parainfo)
        print('The model is updated.')

    return pfit, modelobj