ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/models2010/curvefitting.py
Revision: 41
Committed: Tue Jan 18 00:35:23 2011 UTC (8 years, 4 months ago) by clausted
File size: 8430 byte(s)
Log Message:
Moved old data class "SPRdataclass" and accompanying surface interaction model modules to /models2010 subdirectory.  The plan is to implement these models for use with the "ba_class" and the modules in the parent directory.  

Should all the models be added to mdl_module or should they each go in their own module?  I am undecided.  
Line File contents
"""
CurveFitting: Perform the curve fitting for the data object.

Yuhang Wan
Last modified on 100427 (yymmdd) by YW

Typical pipeline (works together with "converter.py" and a model class;
basicmodel is used as the example here):
>----------load data-----------
>import converter as cv
>dat = cv.datafile_read("dat_example.obj")
>----------load model----------
>import basicmodel as bm
>m = bm.basicmodel()
>m = m.load_model('model_example.obj')
>---------do the fitting-------
>import curvefitting as cf
>pfit, m2 = cf.fitting(dat, m)

fitting() returns the fitted floating parameters and a copy of the model.
It also plots the fit and asks on stdin whether that copy should be updated
with the fitted values.

Note: fitting() hands its working state (parameter names, initial values,
fixed values and the model function) to residuals(), lmafit() and
plotfit() through module-level globals.  Those helpers therefore only work
after fitting() has set that state up, and two fits running at the same
time would interfere with each other.
"""
__version__ = "100427"
22
23 import numpy as np
24 import pylab as plt
25 import copy
26 from scipy.optimize import leastsq
27 import time
28 ## import packageClamp_100225 as Pack
29 ## import SPRdataclass_100225 as SPR
30 ## import modelclass1_1 as Model
31
32
def _print_parameter_group(header, value_title, names, values):
    '''Print one "names, then name : value lines" section of the summary.'''
    print('%s %s' % (header, names))
    print(value_title)
    for name, value in zip(names, values):
        print('%s : %s' % (name, value))


def _check_curve_counts(parainfo, n_curve):
    '''Print a warning if a per-curve parameter does not match the data.

    Only the first problem found is reported (the loop stops there).
    '''
    for item in parainfo:
        ntmp = item['number']
        if ntmp == 1:
            # shared by all curves, nothing to check
            continue
        if ntmp != n_curve:
            print('the number of input curves doesn\'t match that of the '
                  'parameter\nPlease check the data and the parameters!')
            break
        value = item['value']
        # A per-curve parameter must carry one value per curve in a list;
        # calling len() on a scalar here used to crash with TypeError.
        if not isinstance(value, list) or ntmp != len(value):
            print('Please check the infomation of the parameters!')
            break


def initialize(data, parainfo):
    '''Split the model parameters into floating and fixed ones and check them.

    Parameters
    ----------
    data : sequence
        Alternating time and response arrays [t0, y0, t1, y1, ...]; only its
        length is used here, to derive the number of curves.
    parainfo : list of dict
        One dict per parameter with the keys 'name', 'value', 'fixed'
        (1 = fixed, 0 = floating) and 'number' (1, or the number of curves,
        in which case 'value' should be a list of that length).

    Returns
    -------
    (pfloat0_1D, pfix, pfloat_name_1D, pfix_name)
        Initial floating values and their names flattened to the 1D layout
        (see resizePfloat), followed by the fixed values and their names.

    Side effects: stores pfloat_name, pfloat0 (mixed layout) and pfix_name
    as module globals (fitting() later reads pfloat_name) and prints a
    summary.  Problems found in `parainfo` are printed, not raised.
    '''
    global pfloat_name, pfloat0, pfix_name
    # separate the parameters to fit from the fixed ones
    pfloat_name, pfloat0, pfix_name, pfix = [], [], [], []
    for item in parainfo:
        if item['fixed'] == 1:      # fixed parameters
            pfix_name.append(item['name'])
            pfix.append(item['value'])
        elif item['fixed'] == 0:    # floating parameters
            pfloat_name.append(item['name'])
            pfloat0.append(item['value'])
        else:
            print('unreadable status.\nplease check the parameter infomation!')

    print('--'*30)
    _print_parameter_group('The fixed parameters are:', 'The fixed value:',
                           pfix_name, pfix)
    print('\n' + '--'*30)
    _print_parameter_group('The float parameters are:', 'The initial value:',
                           pfloat_name, pfloat0)
    print('\n' + '--'*30)

    # format examination: do the parameters match the data to fit?
    # ('//' keeps the curve count an int on Python 3 as well)
    _check_curve_counts(parainfo, len(data) // 2)

    pfloat0_1D, pfloat_name_1D = resizePfloat(pfloat0, pfloat_name, 0)
    return pfloat0_1D, pfix, pfloat_name_1D, pfix_name
81
def resizePfloat(p, pname, flag = 0):
    '''Convert the floating-parameter list between its two layouts.

    The layouts are:
      1D:    [p_a, p_b, p_c1, p_c2, p_c3, p_d]    (the form passed to leastsq)
      mixed: [p_a, p_b, [p_c1, p_c2, p_c3], p_d]  (one entry per parameter;
                                                  a list holds one value
                                                  per curve)

    Parameters
    ----------
    p : sequence
        Parameter values in the layout selected by `flag`.
    pname : list of str
        Names matching `p` entry by entry.  In the 1D layout the name of a
        per-curve parameter is repeated once per value, and the repeats
        must be contiguous (as produced by flag=0).
    flag : int
        0 converts mixed -> 1D, 1 converts 1D -> mixed.  Any other value
        yields two empty lists.

    Returns
    -------
    (new_p, new_pname) : two lists in the target layout.
    '''
    # first examine the input
    if len(p) != len(pname):
        print('The parameter list doesn\'t match the name list, please check again!')

    new_p, new_pname = [], []

    if flag == 0:
        # From mixed list to one dim list.  Values are paired with names by
        # position; the old p.index(value) lookup returned the FIRST equal
        # value, so two parameters with the same starting value got the
        # same name and were later merged into one per-curve list.
        for i, value in enumerate(p):
            name = pname[i]
            if isinstance(value, list):
                new_p.extend(value)
                new_pname.extend([name] * len(value))
            else:
                new_p.append(value)
                new_pname.append(name)

    elif flag == 1:
        # From one dim list back to mixed list: a name occurring n > 1
        # times marks a per-curve parameter whose n values start at the
        # first occurrence of that name.
        for i, name in enumerate(pname):
            n = pname.count(name)
            if n == 1:
                new_p.append(p[i])
                new_pname.append(name)
            elif name not in new_pname:
                new_pname.append(name)
                new_p.append(list(p[i:i + n]))

    return new_p, new_pname
121
def createdict(p, pname):
    '''Zip parameter names and values into a list of {'name', 'value'} dicts.

    One dict is produced per parameter, in the order of `pname`.  If `p` and
    `pname` differ in length, a diagnostic (including both lists) is printed
    and None is returned instead.'''
    if len(p) != len(pname):
        print('check the input.\n')
        print(pname)
        print(p)
        return None
    return [{'name': name, 'value': value} for name, value in zip(pname, p)]
137
138 ## def keepsinglevalue(paratmp, i):
139 ## for n,p in enumerate(paratmp):
140 ## if type(p['value'])==list:
141 ## paralist[n]['value'] = p['value'][i]
142 ## else:
143 ## paralist[n]['value'] = p['value']
144 ## return paralist
145
146
def residuals(pfloat_1D, data, pfix):
    '''Error function handed to scipy.optimize.leastsq.

    Parameters
    ----------
    pfloat_1D : sequence
        Current values of the floating parameters in the 1D layout (see
        resizePfloat).
    data : sequence
        Alternating time and response arrays [t0, y0, t1, y1, ...].
    pfix : list
        Values of the fixed parameters, matching the global pfix_name.

    Returns
    -------
    numpy.ndarray
        The raw residuals model(t) - y of every curve, concatenated.
        leastsq squares and sums this vector itself, so it must NOT be
        squared here.  The previous version returned squared residuals
        summed across curves, which made leastsq minimise fourth-order
        terms and forced all curves to have the same length.

    Relies on the module globals pfloat_name_1D, pfix_name and
    sprfunction, which fitting() sets up before the fit starts.
    '''
    pfloat, pfloat_name = resizePfloat(pfloat_1D, pfloat_name_1D, 1)

    p_all, pname_all = pfloat + pfix, pfloat_name + pfix_name
    paratmp = createdict(p_all, pname_all)

    parts = []
    for i in range(len(data) // 2):
        # Per-curve parameter set: list-valued (per-curve) parameters
        # contribute their i-th value, scalars are shared by all curves.
        paralist = []
        for p in paratmp:
            value = p['value']
            if isinstance(value, list):
                value = value[i]
            paralist.append({'name': p['name'], 'value': value})

        t = data[i*2]
        y = data[i*2+1]
        diff = np.asarray(sprfunction(t, paralist)) - np.asarray(y)
        parts.append(np.ravel(diff))

    if not parts:
        return np.zeros(0)
    return np.concatenate(parts)
168
169
170
def lmafit(data, pfix, maxfev=10000):
    '''Run the Levenberg-Marquardt fit, then plot and report the result.

    The initial floating-parameter values come from the module global
    pfloat0_1D, already flattened to the one-dimensional layout leastsq
    requires (see resizePfloat).  fitting() sets it up.

    Parameters
    ----------
    data : sequence
        Alternating time and response arrays [t0, y0, t1, y1, ...].
    pfix : list
        Values of the fixed parameters.
    maxfev : int, optional
        Maximum number of function evaluations for leastsq (default 10000,
        the value that used to be hard-coded).

    Returns
    -------
    list
        Fitted floating parameters in the mixed layout (one entry per
        parameter; per-curve parameters are lists).
    '''
    p_1D, _cov, _info, mesg, ier = leastsq(residuals, pfloat0_1D,
                                           args=(data, pfix),
                                           maxfev=maxfev, full_output=True)
    p, _names = resizePfloat(p_1D, pfloat_name_1D, 1)
    # leastsq signals convergence through ier: 1-4 mean a solution was
    # found, anything else is a failure explained by mesg.  Printing the
    # bare ier as "Success" made it look like a boolean.
    print('Success: %s (ier=%s: %s)' % (ier in (1, 2, 3, 4), ier, mesg))
    plotfit(data, p_1D, pfix)

    return p
184
185
def plotfit(data, pfit_1D, pfix):
    '''Plot measured and fitted curves, then print the fitted parameters.

    The plot title shows the sum of squared residuals.  Each curve is drawn
    as points (measured data) plus a line (model evaluated with the fitted
    parameters).

    Parameters
    ----------
    data : sequence
        Alternating time and response arrays [t0, y0, t1, y1, ...].
    pfit_1D : sequence
        Fitted floating parameters in the 1D layout returned by leastsq.
    pfix : list
        Values of the fixed parameters.

    Side effects: sets the module global pfit_name to the mixed-layout
    names of the floating parameters, calls plt.show() and prints a report.
    '''
    global pfit_name
    pfit, pfit_name = resizePfloat(pfit_1D, pfloat_name_1D, 1)
    txt = 'SumSqE: %1.8f' % np.sum(np.square(residuals(pfit_1D, data, pfix)))
    print(txt)
    plt.title(txt)
    plt.xlabel('Time (s)')
    plt.ylabel('Response (uRIU)')
    plt.grid(True)

    p_all, pname_all = pfit + pfix, pfit_name + pfix_name
    paratmp = createdict(p_all, pname_all)

    # '//' keeps the curve count an int (a float breaks range() on Python 3)
    for i in range(len(data) // 2):
        # list-valued parameters hold one value per curve; pick curve i
        paralist = []
        for p in paratmp:
            value = p['value']
            if isinstance(value, list):
                value = value[i]
            paralist.append({'name': p['name'], 'value': value})
        t = data[i*2]
        y = data[i*2+1]
        plt.plot(t, y, ',')                     # Plot real data
        plt.plot(t, sprfunction(t, paralist))   # Plot fitted data
    plt.show()

    print('\n' + '-'*60)
    print('The fitted parameters are: %s' % (pfit_name,))
    print('The fitted value:')
    for name, value in zip(pfit_name, pfit):
        print('%s : %s' % (name, value))
222
223
def updateparainfo(parainfo, pfit, pfit_name):
    '''Return a deep copy of `parainfo` with the fitted values filled in.

    Each entry whose 'name' occurs in `pfit_name` receives the value found
    at the same position in `pfit`.  A comparison table of the initial and
    fitted values is printed along the way.  `parainfo` itself is left
    unchanged.
    '''
    updated = copy.deepcopy(parainfo)

    print('\n' + '+'*25 + ' Comparison ' + '+'*25)
    print('name\t\t initial value \t\t fitted value')
    for entry in updated:
        for k, fitted_name in enumerate(pfit_name):
            if entry['name'] != fitted_name:
                continue
            print('%s %s %s %s %s' % (fitted_name, '\t\t ', entry['value'],
                                      ' \t\t ', pfit[k]))
            entry['value'] = pfit[k]

    return updated
237
238
239
def _ask_yes_no(question):
    '''Print `question` and read y/n from stdin; True only for a "y" answer.'''
    print(question)
    try:
        read = raw_input      # Python 2
    except NameError:
        read = input          # Python 3 (raw_input was renamed)
    try:
        answer = read('y/n : ')
    except EOFError:          # no interactive stdin: treat as "no"
        return False
    return answer.strip().upper() == 'Y'


def fitting(dataobj, mobj, update=None):
    '''The main function of curvefitting: fit a model to a data object.

    Parameters
    ----------
    dataobj : object
        Data object whose `.data` attribute is the alternating list of time
        and response arrays [t0, y0, t1, y1, ...].
    mobj : object
        Model object providing `.parainfo`, `.function` and
        `.updatemodel(parainfo)`.  It is deep-copied, so the caller's model
        is never modified.
    update : bool or None, optional
        Whether to write the fitted values into the returned model copy.
        None (the default) asks on stdin as before.  True/False skip the
        prompt, which makes the function usable from scripts.

    Returns
    -------
    (pfit, modelobj)
        The fitted floating parameters in the mixed layout, and the
        (possibly updated) copy of the model.
    '''
    # residuals(), lmafit() and plotfit() read these through module globals.
    global pfloat0_1D, pfix, pfloat_name_1D, pfix_name, sprfunction
    modelobj = copy.deepcopy(mobj)
    time1 = time.time()
    data = dataobj.data
    parainfo = modelobj.parainfo
    pfloat0_1D, pfix, pfloat_name_1D, pfix_name = initialize(data, parainfo)
    sprfunction = modelobj.function

    pfit = lmafit(data, pfix)
    # pfloat_name is the mixed-layout name list stored as a module global
    # by initialize(); it lines up entry by entry with the mixed-layout pfit.
    parainfo = updateparainfo(parainfo, pfit, pfloat_name)

    time2 = time.time()
    print('-'*60)
    print("Time elapsed: %s seconds\n" % (time2 - time1))
    if update is None:
        update = _ask_yes_no('Do you want to update the model with fitted value?')
    if update:
        modelobj.updatemodel(parainfo)
        print('The model is updated.')

    return pfit, modelobj