ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/curvefitting.py
Revision: 25
Committed: Wed Apr 28 20:22:24 2010 UTC (8 years, 11 months ago) by rjaynes
File size: 8430 byte(s)
Log Message:
Add py and obj files to allow modeling of more SPR experiments with converter and curvefitting modules.  This is the work of Yuhang Wan and Rui Hou.

1. In "converter.py": 
      Add the saving and reading function for the sprclass data object.
      Also add the function "keyfile_read_fake" to provide default information for the SPRit and ICM formats, to avoid the bug that occurs when running background_subtract.
      Fix the bugs in "background_subtract".
      Tested by DAM and ICM formats.
2. In model modules:
      "modelclass.py" is the parent class for all the other model classes; it performs the theoretical simulation and the loading and saving of the parameters or simulated data. Rui and I also added some other model modules, such as the competing model, twostate model, parallel model, and the time-variable concentration models, whose simulated results were compared with Clamp's simulation to make sure the equations are correct. 
       The basicmodel and basicmodel_varyC class are tested. 
3. In "curvefitting.py":
      Add typical pipeline for operation. The examples are packed with the file. 
      Add function to show the Elapsed time for each fitting.
Line File contents
1 """
2 CurveFitting: Perform the curve fitting for the data object.
3
4 Yuhang Wan
5 Last modified on 100427 (yymmdd) by YW
6
7 Typical Pipeline:(work together with "converter.py" and some model class
8 here take the basicmodel as example )
9 >----------load data-----------
10 >import converter as cv
11 >dat = cv.datafile_read("dat_example.obj")
12 >----------load model----------
13 >import basicmodel as bm
14 >m = bm.basicmodel()
15 >m = m.load_model('model_example.obj')
16 >---------do the fitting-------
17 >import curvefitting as cf
18 >pfit, m2 = cf.fitting(dat, m)
19
20 """
21 __version__ = "100427"
22
23 import numpy as np
24 import pylab as plt
25 import copy
26 from scipy.optimize import leastsq
27 import time
28 ## import packageClamp_100225 as Pack
29 ## import SPRdataclass_100225 as SPR
30 ## import modelclass1_1 as Model
31
32
def initialize(data, parainfo):
    """Sort the parameters into float and fixed groups and report them.

    Every entry of ``parainfo`` is read through the keys ``'name'``,
    ``'value'``, ``'fixed'`` (1 = fixed, 0 = float) and ``'number'``.
    ``data`` is only inspected for its length: ``len(data) // 2`` is used
    as the number of curves in the format check.

    Side effect: rebinds the module globals ``pfloat_name``, ``pfloat0``
    and ``pfix_name``.

    Returns ``(pfloat0_1D, pfix, pfloat_name_1D, pfix_name)``, where the
    two ``_1D`` lists are the float values and names as flattened by
    ``resizePfloat(..., 0)``.
    """
    global pfloat_name, pfloat0, pfix_name
    # Separate the parameters to fit from those kept fixed.
    pfloat_name, pfloat0, pfix_name, pfix = [], [], [], []
    for entry in parainfo:
        status = entry['fixed']
        if status == 1:
            pfix_name.append(entry['name'])
            pfix.append(entry['value'])
        elif status == 0:
            pfloat_name.append(entry['name'])
            pfloat0.append(entry['value'])
        else:
            print('unreadable status.\nplease check the parameter infomation!')

    rule = '--' * 30
    print(rule)
    print('The fixed parameters are: %s' % (pfix_name,))
    print('The fixed value:')
    for name, value in zip(pfix_name, pfix):
        print('%s : %s' % (name, value))

    print('\n' + rule)
    print('The float parameters are: %s' % (pfloat_name,))
    print('The initial value:')
    for name, value in zip(pfloat_name, pfloat0):
        print('%s : %s' % (name, value))

    print('\n' + rule)

    # Format examination: a parameter is either shared (number == 1) or has
    # one value per curve.  The scan stops at the first mismatch; it only
    # reports, it does not abort the initialisation.
    # NOTE(review): the nesting below is reconstructed; the length check is
    # assumed to apply only when 'number' equals the curve count -- confirm.
    n_curve = len(data) // 2
    for entry in parainfo:
        count = entry['number']
        if count == 1:
            continue
        if count != n_curve:
            print('the number of input curves doesn\'t match that of the parameter\nPlease check the data and the parameters!')
            break
        if count != len(entry['value']):
            print('Please check the infomation of the parameters!')
            break

    pfloat0_1D, pfloat_name_1D = resizePfloat(pfloat0, pfloat_name, 0)
    return pfloat0_1D, pfix, pfloat_name_1D, pfix_name
81
def resizePfloat(p, pname, flag = 0):
    """Convert a float-parameter list between its two layouts.

    The two layouts are:
        1D:    [p_a, p_b, p_c1, p_c2, p_c3, p_d]
        mixed: [p_a, p_b, [p_c1, p_c2, p_c3], p_d]

    p     -- parameter values in the source layout.
    pname -- names parallel to ``p``; in the 1D layout a name is repeated
             once per value that belongs to it.
    flag  -- 0: mixed -> 1D (default); 1: 1D -> mixed.  Any other value
             yields two empty lists.

    Returns ``(new_p, new_pname)`` in the target layout.  A length mismatch
    between ``p`` and ``pname`` is only reported, not raised.
    """
    # First examine the input.
    if len(p) != len(pname):
        print('The parameter list doesn\'t match the name list, please check again!')

    new_p, new_pname = [], []

    if flag == 0:
        # Mixed -> 1D.  Values and names are paired by position; looking the
        # name up through p.index(value) would give the wrong name whenever
        # two parameters happen to hold equal values.
        for value, name in zip(p, pname):
            if isinstance(value, list):
                new_p.extend(value)
                new_pname.extend([name] * len(value))
            else:
                new_p.append(value)
                new_pname.append(name)

    elif flag == 1:
        # 1D -> mixed.  A repeated name marks a run of values that is packed
        # back into one list at the position of its first occurrence.
        for i, name in enumerate(pname):
            n = pname.count(name)
            if n == 1:
                new_p.append(p[i])
                new_pname.append(name)
            elif name not in new_pname:
                new_pname.append(name)
                # list() also turns an array slice into a plain list.
                new_p.append(list(p[i:i + n]))

    return new_p, new_pname
121
def createdict(p,pname):
    '''Bundle parallel value and name sequences into a list of dicts.

    Each resulting dict has the keys 'name' and 'value'.  When the two
    inputs differ in length, the inputs are printed and None is returned.
    '''
    if len(p) != len(pname):
        print('check the input.\n')
        print(pname)
        print(p)
        return
    return [{'name': label, 'value': p[pos]} for pos, label in enumerate(pname)]
137
138 ## def keepsinglevalue(paratmp, i):
139 ## for n,p in enumerate(paratmp):
140 ## if type(p['value'])==list:
141 ## paralist[n]['value'] = p['value'][i]
142 ## else:
143 ## paralist[n]['value'] = p['value']
144 ## return paralist
145
146
def residuals(pfloat_1D, data, pfix):
    '''Error function handed to leastsq.

    pfloat_1D -- current values of the float parameters, one dimension.
    data      -- flat sequence [t0, y0, t1, y1, ...]; curve i is read from
                 data[2*i] and data[2*i+1].
    pfix      -- values of the fixed parameters.

    Reads the module globals pfloat_name_1D, pfix_name and sprfunction.

    Returns a 1D array holding (model - measurement) for every point of
    every curve, concatenated curve after curve.  The residuals are left
    unsquared on purpose: leastsq squares and sums the returned vector
    itself, so squaring here would make the fit minimise fourth powers.
    Concatenating (rather than adding the curves element-wise) also lets
    curves of different lengths be fitted together.
    '''
    pfloat, float_names = resizePfloat(pfloat_1D, pfloat_name_1D, 1)

    p_all, pname_all = pfloat + pfix, float_names + pfix_name
    paratmp = createdict(p_all, pname_all)
    paralist = copy.deepcopy(paratmp)

    parts = []
    for i in range(len(data) // 2):
        # Per-curve parameters are stored as lists: pick the i-th value.
        # Shared parameters are passed through unchanged.
        for n, p in enumerate(paratmp):
            if isinstance(p['value'], list):
                paralist[n]['value'] = p['value'][i]
            else:
                paralist[n]['value'] = p['value']

        t = data[i*2]
        y = data[i*2+1]
        diff = np.asarray(sprfunction(t, paralist)) - np.asarray(y)
        parts.append(np.ravel(diff))

    if not parts:
        # No curves at all: nothing to compare.
        return np.zeros(0)
    return np.concatenate(parts)
168
169
170
def lmafit(data,pfix):
    '''Run the least-squares fit and return the fitted float parameters.

    The starting values come from the module global pfloat0_1D, which is
    already one-dimensional as leastsq requires.  The status flag returned
    by leastsq is printed, the result is passed to plotfit, and the fitted
    values are returned in the mixed layout produced by resizePfloat.
    '''
    fit_1D, status = leastsq(residuals, pfloat0_1D,
                             args=(data, pfix), maxfev=10000)
    # Only the values are needed here; the regrouped names are discarded.
    fitted = resizePfloat(fit_1D, pfloat_name_1D, 1)[0]
    print('Success: %s' % (status,))
    plotfit(data, fit_1D, pfix)

    return fitted
184
185
def plotfit(data, pfit_1D, pfix):
    '''Plot the measured and the fitted curves and print the fitted values.

    The figure title (also printed) is the sum of the squares of whatever
    residuals() returns for the fitted parameters.

    Side effect: rebinds the module global pfit_name.
    Reads the module globals pfloat_name_1D, pfix_name and sprfunction.
    '''
    global pfit_name
    pfit, pfit_name = resizePfloat(pfit_1D, pfloat_name_1D, 1)
    txt = 'SumSqE: %1.8f' % sum(np.square(residuals(pfit_1D, data, pfix)))
    print(txt)
    plt.title(txt)
    plt.xlabel('Time (s)')
    plt.ylabel('Response (uRIU)')
    plt.grid(True)

    templates = createdict(pfit + pfix, pfit_name + pfix_name)
    per_curve = copy.deepcopy(templates)

    for curve in range(len(data) // 2):
        # List-valued parameters carry one value per curve.
        for slot, template in zip(per_curve, templates):
            value = template['value']
            if type(value) is list:
                slot['value'] = value[curve]
            else:
                slot['value'] = value
        t = data[curve * 2]
        y = data[curve * 2 + 1]
        plt.plot(t, y, ',')                       # measured data
        plt.plot(t, sprfunction(t, per_curve))    # fitted curve
    # NOTE(review): show() is assumed to be called once, after all curves
    # are drawn -- confirm against the original indentation.
    plt.show()

    print('\n' + '-' * 60)
    print('The fitted parameters are: %s' % (pfit_name,))
    print('The fitted value:')
    for name, value in zip(pfit_name, pfit):
        print('%s : %s' % (name, value))

    return
222
223
def updateparainfo(parainfo, pfit, pfit_name):
    """Return a copy of ``parainfo`` with the fitted values filled in.

    ``parainfo`` itself is left untouched (a deep copy is modified).  For
    every entry whose ``'name'`` appears in ``pfit_name``, a comparison
    line (name, previous value, fitted value) is printed and the entry's
    ``'value'`` is replaced by the matching element of ``pfit``.
    """
    new_parainfo = copy.deepcopy(parainfo)

    print('\n' + '+' * 25 + ' Comparison ' + '+' * 25)
    print('name\t\t initial value \t\t fitted value')
    for entry in new_parainfo:
        for pos, fitted_name in enumerate(pfit_name):
            if entry['name'] != fitted_name:
                continue
            print('%s \t\t  %s  \t\t  %s'
                  % (fitted_name, entry['value'], pfit[pos]))
            entry['value'] = pfit[pos]

    return new_parainfo
237
238
239
def fitting(dataobj, mobj):
    '''The main function of curvefitting.

    dataobj -- data object; its ``data`` attribute is fitted.
    mobj    -- model object; a deep copy is used, reading its ``parainfo``
               and ``function`` attributes, so ``mobj`` is not modified.

    Side effect: rebinds the module globals pfloat0_1D, pfix,
    pfloat_name_1D, pfix_name and sprfunction, which the other functions
    of this module read.  Prints the elapsed time and then asks on the
    console whether the model copy should be updated with the fitted
    values.

    Returns ``(pfit, modelobj)``: the fitted float parameters and the model
    copy (updated only if the user answered 'y').
    '''
    global pfloat0_1D, pfix, pfloat_name_1D, pfix_name, sprfunction
    modelobj = copy.deepcopy(mobj)
    started = time.time()

    data = dataobj.data
    parainfo = modelobj.parainfo
    pfloat0_1D, pfix, pfloat_name_1D, pfix_name = initialize(data, parainfo)
    sprfunction = modelobj.function

    pfit = lmafit(data, pfix)
    # pfloat_name is the module global set by initialize().
    parainfo = updateparainfo(parainfo, pfit, pfloat_name)

    finished = time.time()
    print('-' * 60)
    print("Time elapsed: %s seconds\n" % (finished - started))
    print('Do you want to update the model with fitted value?')
    answer = raw_input('y/n : ')
    if answer.upper() == 'Y':
        modelobj.updatemodel(parainfo)
        print('The model is updated.')

    return pfit, modelobj