ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/curvefitting.py
Revision: 25
Committed: Wed Apr 28 20:22:24 2010 UTC (9 years, 4 months ago) by rjaynes
File size: 8430 byte(s)
Log Message:
Add py and obj files to allow modeling of more SPR experiments with converter and curvefitting modules.  This is the work of Yuhang Wan and Rui Hou.

1. In "converter.py": 
      Add the saving and reading function for the sprclass data object.
      Also add the function "keyfile_read_fake", which provides default information for the SPRit and ICM formats in case of a bug when running background_subtract.
      Fix the bugs in "background_subtract".
      Tested by DAM and ICM formats.
2. In model modules:
      "modelclass.py" is the parent class for all the other model classes that performs the theoretical simulating, loading and saving of the parameter or simulated data. Rui and I also add some other model modules like competing model, twostate model, parallel model, and the time variable concentrated models, where the simulated result is compared with Clamp's simulation to make sure the equations are correct. 
       The basicmodel and basicmodel_varyC class are tested. 
3. In "curvefitting.py":
      Add typical pipeline for operation. The examples are packed with the file. 
      Add function to show the Elapsed time for each fitting.
Line User Rev File contents
1 clausted 1 """
2 rjaynes 25 CurveFitting: Perform the curve fitting for the data object.
3 clausted 1
4 rjaynes 25 Yuhang Wan
5     Last modified on 100427 (yymmdd) by YW
6 clausted 1
7 rjaynes 25 Typical pipeline (used together with "converter.py" and a model class;
8     basicmodel is taken as the example here):
9     >----------load data-----------
10     >import converter as cv
11     >dat = cv.datafile_read("dat_example.obj")
12     >----------load model----------
13     >import basicmodel as bm
14     >m = bm.basicmodel()
15     >m = m.load_model('model_example.obj')
16     >---------do the fitting-------
17     >import curvefitting as cf
18     >pfit, m2 = cf.fitting(dat, m)
19 clausted 1
20     """
21 rjaynes 25 __version__ = "100427"
22 clausted 1
23     import numpy as np
24     import pylab as plt
25     import copy
26     from scipy.optimize import leastsq
27 rjaynes 25 import time
28 clausted 1 ## import packageClamp_100225 as Pack
29     ## import SPRdataclass_100225 as SPR
30     ## import modelclass1_1 as Model
31    
32    
def initialize(data, parainfo):
    """Split the parameter records into floating and fixed groups.

    data     -- sequence whose length is twice the number of curves
    parainfo -- list of dicts, each read through the keys 'name', 'value',
                'fixed' (1 = fixed, 0 = floating) and 'number'

    The groups are printed, the 'number' entries are checked against the
    curve count (a mismatch is only reported, not raised), and the floating
    group is flattened with resizePfloat().

    Side effect: rebinds the module globals pfloat_name, pfloat0 and
    pfix_name.

    Returns (pfloat0_1D, pfix, pfloat_name_1D, pfix_name).
    """
    global pfloat_name, pfloat0, pfix_name

    # Separate the parameters to fit from the ones that stay fixed.
    pfloat_name, pfloat0, pfix_name, pfix = [], [], [], []
    for entry in parainfo:
        if entry['fixed'] == 1:
            pfix_name.append(entry['name'])
            pfix.append(entry['value'])
        elif entry['fixed'] == 0:
            pfloat_name.append(entry['name'])
            pfloat0.append(entry['value'])
        else:
            # NOTE(review): the original literal used a backslash line
            # continuation; any indentation it embedded is not visible in
            # this listing -- confirm the exact spacing.
            print('unreadable status.\n'
                  'please check the parameter infomation!')

    rule = '--' * 30
    print(rule)
    print('The fixed parameters are: %s' % (pfix_name,))
    print('The fixed value:')
    for name, value in zip(pfix_name, pfix):
        print(' '.join([str(name), ':', str(value)]))

    print('\n' + rule)
    print('The float parameters are: %s' % (pfloat_name,))
    print('The initial value:')
    for name, value in zip(pfloat_name, pfloat0):
        print(' '.join([str(name), ':', str(value)]))

    print('\n' + rule)

    # Format examination: does every multi-valued parameter match the
    # number of curves in the data?
    # NOTE(review): nesting was reconstructed from an unindented listing;
    # the length check is taken to run only when 'number' already equals
    # the curve count -- confirm against the repository copy.
    n_curve = len(data) // 2
    for entry in parainfo:
        declared = entry['number']
        if declared == 1:
            continue
        if declared != n_curve:
            print('the number of input curves doesn\'t match that of the '
                  'parameter\n'
                  'Please check the data and the parameters!')
            break
        if declared != len(entry['value']):
            print('Please check the infomation of the parameters!')
            break

    flat_values, flat_names = resizePfloat(pfloat0, pfloat_name, 0)
    return flat_values, pfix, flat_names, pfix_name
81    
def resizePfloat(p, pname, flag = 0):
    """Convert a parameter list between its mixed and one-dimensional forms.

    The two layouts are:
        1D:    [p_a, p_b, p_c1, p_c2, p_c3, p_d]
        mixed: [p_a, p_b, [p_c1, p_c2, p_c3], p_d]

    p     -- parameter values in the layout being converted from
    pname -- names aligned index-by-index with p; in the 1D layout the name
             of a multi-valued parameter is repeated once per value, in
             consecutive positions
    flag  -- 0 flattens mixed -> 1D, 1 regroups 1D -> mixed; any other
             value yields two empty lists

    A length mismatch between p and pname is reported on stdout but not
    raised.

    Returns (new_p, new_pname).
    """
    # First examine the input.
    if len(p) != len(pname):
        print('The parameter list doesn\'t match the name list, please check again!')

    new_p, new_pname = [], []

    if flag == 0:
        # From mixed list to one-dimensional list.
        # Values and names are paired by position. Looking the name up with
        # p.index(value) returns the first *equal* value, which mislabels a
        # parameter whenever two parameters share the same value.
        for idx, value in enumerate(p):
            name = pname[idx]
            if isinstance(value, list):
                new_p.extend(value)
                new_pname.extend([name] * len(value))
            else:
                new_p.append(value)
                new_pname.append(name)

    elif flag == 1:
        # From one-dimensional list back to mixed list.
        for idx, name in enumerate(pname):
            repeats = pname.count(name)
            if repeats == 1:
                new_p.append(p[idx])
                new_pname.append(name)
            elif name not in new_pname:
                # First occurrence of a repeated name: collect its run of
                # values into one nested list.
                new_pname.append(name)
                new_p.append(list(p[idx:idx + repeats]))

    return new_p, new_pname
121    
def createdict(p,pname):
    '''Pack parameter names and values into a list of dictionaries.

    p     -- parameter values
    pname -- parameter names, aligned index-by-index with p

    Returns a list of {'name': ..., 'value': ...} dicts in the order of
    pname. When the two inputs differ in length, both are printed and
    None is returned.
    '''
    if len(p) != len(pname):
        print('check the input.\n')
        print(pname)
        print(p)
        return None

    return [{'name': label, 'value': p[pos]}
            for pos, label in enumerate(pname)]
137    
138     ## def keepsinglevalue(paratmp, i):
139     ## for n,p in enumerate(paratmp):
140     ## if type(p['value'])==list:
141     ## paralist[n]['value'] = p['value'][i]
142     ## else:
143     ## paralist[n]['value'] = p['value']
144     ## return paralist
145    
146    
def residuals(pfloat_1D, data, pfix):
    '''Error function for scipy.optimize.leastsq.

    pfloat_1D -- current values of the floating parameters, 1D layout
    data      -- flat sequence [t0, y0, t1, y1, ...]; curve i is read from
                 data[2*i] (time) and data[2*i+1] (response)
    pfix      -- values of the fixed parameters

    Reads the module globals pfloat_name_1D, pfix_name and sprfunction,
    which fitting() assigns before the fit starts.

    Returns a 1-D array holding (model - measured) for every point of every
    curve, one curve after another. The values are deliberately NOT
    squared: leastsq squares and sums the returned vector itself, so
    returning squared errors would make it minimise fourth powers.
    Concatenating also lets the curves have different lengths.
    '''
    # The float parameters arrive in one dimension; regroup them first.
    pfloat, pfloat_name = resizePfloat(pfloat_1D, pfloat_name_1D, 1)

    p_all, pname_all = pfloat + pfix, pfloat_name + pfix_name
    paratmp = createdict(p_all, pname_all)
    paralist = copy.deepcopy(paratmp)

    pieces = []
    for i in range(len(data) // 2):
        # A list-valued parameter holds one value per curve; pick curve i's.
        for n, p in enumerate(paratmp):
            value = p['value']
            paralist[n]['value'] = value[i] if isinstance(value, list) else value

        t = data[i*2]
        y = data[i*2+1]
        pieces.append(np.ravel(sprfunction(t, paralist) - y))

    if not pieces:
        # No curves at all: np.concatenate rejects an empty list.
        return np.zeros(0)
    return np.concatenate(pieces)
168    
169    
170    
def lmafit(data,pfix):
    '''Run the least-squares fit and return the fitted float parameters.

    data -- curve data passed through to residuals() and plotfit()
    pfix -- values of the fixed parameters

    Reads the module globals pfloat0_1D (initial guess) and pfloat_name_1D.
    The initial values are already in the one-dimensional layout that
    leastsq requires; the solution is converted back to the mixed layout
    with resizePfloat() before it is returned.
    '''
    solution_1D, status = leastsq(residuals, pfloat0_1D,
                                  args=(data, pfix), maxfev=10000)
    regrouped = resizePfloat(solution_1D, pfloat_name_1D, 1)

    # 'status' is the second value leastsq returns; it is printed as-is.
    print(' '.join(['Success:', str(status)]))
    plotfit(data, solution_1D, pfix)

    return regrouped[0]
184    
185    
def plotfit(data, pfit_1D, pfix):
    '''Plot measured and fitted curves and print the fitted parameters.

    data    -- flat sequence [t0, y0, t1, y1, ...]
    pfit_1D -- fitted float parameters in the one-dimensional layout
    pfix    -- values of the fixed parameters

    The plot title is the sum of the squared entries of whatever
    residuals() returns for the fitted parameters. Reads the module globals
    pfloat_name_1D, pfix_name and sprfunction, and rebinds the module
    global pfit_name. Returns None.
    '''
    global pfit_name
    pfit, pfit_name = resizePfloat(pfit_1D, pfloat_name_1D, 1)

    error_vector = residuals(pfit_1D, data, pfix)
    txt = 'SumSqE: %1.8f' % sum(np.square(error_vector))
    print(txt)
    plt.title(txt)
    plt.xlabel('Time (s)')
    plt.ylabel('Response (uRIU)')
    plt.grid(True)

    records = createdict(pfit + pfix, pfit_name + pfix_name)
    per_curve = copy.deepcopy(records)

    for curve in range(len(data) // 2):
        # A list-valued parameter carries one value per curve.
        for slot, record in zip(per_curve, records):
            current = record['value']
            if type(current) is list:
                slot['value'] = current[curve]
            else:
                slot['value'] = current
        t = data[curve * 2]
        y = data[curve * 2 + 1]
        plt.plot(t, y, ',')                        # measured data
        plt.plot(t, sprfunction(t, per_curve))     # fitted curve
    # NOTE(review): show() is taken to follow the loop (one figure for all
    # curves); indentation is not visible in this listing -- confirm.
    plt.show()

    print('\n' + '-' * 60)
    print('The fitted parameters are: %s' % (pfit_name,))
    print('The fitted value:')
    for label, fitted in zip(pfit_name, pfit):
        print(' '.join([str(label), ':', str(fitted)]))

    return
222    
223    
def updateparainfo(parainfo, pfit, pfit_name):
    '''Return a copy of parainfo with the fitted values written in.

    parainfo  -- list of parameter dicts (keys 'name' and 'value' are used)
    pfit      -- fitted values, aligned index-by-index with pfit_name
    pfit_name -- names of the fitted parameters

    parainfo itself is left untouched: the update is made on a deep copy.
    A comparison table of initial and fitted values is printed on the way.
    '''
    new_parainfo = copy.deepcopy(parainfo)

    banner = '+' * 25
    print('\n' + ' '.join([banner, 'Comparison', banner]))
    print('name\t\t initial value \t\t fitted value')
    for record in new_parainfo:
        for pos, fitted_name in enumerate(pfit_name):
            if record['name'] != fitted_name:
                continue
            print(' '.join([str(fitted_name), '\t\t ', str(record['value']),
                            ' \t\t ', str(pfit[pos])]))
            record['value'] = pfit[pos]

    return new_parainfo
237    
238    
239    
def fitting(dataobj, mobj):
    '''The main function of curvefitting.

    dataobj -- data object; its .data attribute is the curve data
    mobj    -- model object; a deep copy is made and its .parainfo,
               .function and .updatemodel() are used, so mobj itself is not
               modified here

    Rebinds the module globals pfloat0_1D, pfix, pfloat_name_1D, pfix_name
    and sprfunction, which residuals(), lmafit() and plotfit() read. After
    the fit the elapsed time is printed and the user is asked whether the
    copied model should be updated with the fitted values.

    Returns (pfit, modelobj): the fitted float parameters and the model
    copy (updated only if the user answered 'y').
    '''
    global pfloat0_1D, pfix, pfloat_name_1D, pfix_name, sprfunction

    modelobj = copy.deepcopy(mobj)
    started = time.time()

    data = dataobj.data
    parainfo = modelobj.parainfo
    pfloat0_1D, pfix, pfloat_name_1D, pfix_name = initialize(data, parainfo)
    sprfunction = modelobj.function

    pfit = lmafit(data, pfix)
    # pfloat_name is the module global that initialize() assigns.
    parainfo = updateparainfo(parainfo, pfit, pfloat_name)

    elapsed = time.time() - started
    print('-' * 60)
    print("Time elapsed: %s seconds\n" % (elapsed))
    print('Do you want to update the model with fitted value?')
    answer = raw_input('y/n : ')
    if str.upper(answer) == 'Y':
        modelobj.updatemodel(parainfo)
        print('The model is updated.')

    return pfit, modelobj