ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/curvefitting.py
Revision: 1
Committed: Wed Mar 17 05:34:43 2010 UTC (9 years, 6 months ago) by clausted
File size: 7738 byte(s)
Log Message:
Initial import of Osprai project
Line User Rev File contents
1 clausted 1
2     """
3     Perform the curve fitting for the data object.
4    
5     Yuhang Wan, Feb25, 2010
6     __version__ = "1.0"
7    
8    
9     """
10    
11    
12     import numpy as np
13     import pylab as plt
14     import copy
15     from scipy.optimize import leastsq
16     ## import packageClamp_100225 as Pack
17     ## import SPRdataclass_100225 as SPR
18     ## import modelclass1_1 as Model
19    
20    
def initialize(data, parainfo):
    """Split the parameters into fixed and floating sets and report them.

    Each entry of ``parainfo`` is a dict that is read through the keys
    ``'name'``, ``'value'``, ``'fixed'`` (1 = fixed, 0 = floating) and
    ``'number'``.  ``data`` is only used for its length here: it is treated
    as a flat sequence holding two entries per curve.

    Side effects: rebinds the module-level names ``pfloat_name``, ``pfloat0``
    and ``pfix_name``, and prints a summary of both parameter sets.

    Returns a 4-tuple ``(pfloat0_1D, pfix, pfloat_name_1D, pfix_name)`` where
    the ``*_1D`` items are the floating values/names flattened by
    ``resizePfloat(..., 0)``.

    The print calls use a single pre-formatted argument so that the emitted
    text is the same on Python 2 and Python 3.
    """
    global pfloat_name, pfloat0, pfix_name
    # Separate the parameters to fit from those held fixed.
    pfloat_name, pfloat0, pfix_name, pfix = [], [], [], []
    for info in parainfo:
        if info['fixed'] == 1:      # fixed parameter
            pfix_name.append(info['name'])
            pfix.append(info['value'])
        elif info['fixed'] == 0:    # floating parameter
            pfloat_name.append(info['name'])
            pfloat0.append(info['value'])
        else:
            # Anything other than 0/1 is reported and the entry is skipped.
            print('unreadable status.\n'
                  'please check the parameter infomation!')

    rule = '-' * 50
    print(rule)
    print('The fixed parameters are: %s' % (pfix_name,))
    print('The fixed value:')
    for name, value in zip(pfix_name, pfix):
        print('%s : %s' % (name, value))

    print('\n' + rule)
    print('The float parameters are: %s' % (pfloat_name,))
    print('The initial value:')
    for name, value in zip(pfloat_name, pfloat0):
        print('%s : %s' % (name, value))

    print('\n' + rule)

    # Format examination: check that multi-valued parameters match the
    # number of curves in the data.  Problems are only reported; the first
    # one found stops the examination but not the function.
    n_curve = len(data) // 2    # floor division: same result on Python 2 and 3
    for info in parainfo:
        count = info['number']
        if count == 1:
            continue
        if count != n_curve:
            print("the number of input curves doesn't match that of the "
                  "parameter\n"
                  "Please check the data and the parameters!")
            break
        if count != len(info['value']):
            print('Please check the infomation of the parameters!')
            break

    pfloat0_1D, pfloat_name_1D = resizePfloat(pfloat0, pfloat_name, 0)
    return pfloat0_1D, pfix, pfloat_name_1D, pfix_name
69    
def resizePfloat(p, pname, flag=0):
    """Convert a parameter list between its "mixed" and one-dimensional forms.

    The two formats are::

        1D:    [p_a, p_b, p_c1, p_c2, p_c3, p_d]
        mixed: [p_a, p_b, [p_c1, p_c2, p_c3], p_d]

    p     -- parameter values (mixed list for flag 0, flat sequence for flag 1)
    pname -- parameter names, one per entry of ``p``
    flag  -- 0: mixed -> 1D (each name is repeated once per flattened value);
             1: 1D -> mixed (consecutive entries sharing a name are grouped
                back into one list, taken from the first occurrence of that
                name).

    Returns ``(new_p, new_pname)``.  Any other ``flag`` yields two empty
    lists.  A length mismatch between ``p`` and ``pname`` is reported with a
    message but not otherwise handled.
    """
    # First examine the input.
    if len(p) != len(pname):
        print('The parameter list doesn\'t match the name list, please check again!')

    new_p, new_pname = [], []

    if flag == 0:
        # From mixed list to one-dimensional list.  Values and names are
        # paired by position: looking the name up through p.index(value)
        # would hand the first parameter's name to every later parameter
        # that happens to have an equal value.
        for idx, value in enumerate(p):
            name = pname[idx]
            if isinstance(value, list):
                new_p.extend(value)
                new_pname.extend([name] * len(value))
            else:
                new_p.append(value)
                new_pname.append(name)

    elif flag == 1:
        # From one-dimensional list back to mixed list.
        for idx, name in enumerate(pname):
            n = pname.count(name)
            if n == 1:
                new_p.append(p[idx])
                new_pname.append(name)
            elif name not in new_pname:
                # First occurrence of a repeated name: collect its n values.
                new_pname.append(name)
                new_p.append(list(p[idx:idx + n]))

    return new_p, new_pname
109    
def createdict(p, pname):
    """Pack the name and value of all the parameters into a dictionary list.

    p     -- parameter values
    pname -- parameter names, one per value

    Returns a list of ``{'name': ..., 'value': ...}`` dicts in input order.
    If the two inputs differ in length, both are printed for inspection and
    ``None`` is returned.
    """
    if len(p) != len(pname):
        # Single-argument print calls behave the same on Python 2 and 3.
        print('check the input.\n')
        print(pname)
        print(p)
        return None
    return [{'name': name, 'value': value} for name, value in zip(pname, p)]
125    
126     ## def keepsinglevalue(paratmp, i):
127     ## for n,p in enumerate(paratmp):
128     ## if type(p['value'])==list:
129     ## paralist[n]['value'] = p['value'][i]
130     ## else:
131     ## paralist[n]['value'] = p['value']
132     ## return paralist
133    
134    
def residuals(pfloat_1D, data, pfix):
    """Error function handed to ``leastsq``.

    pfloat_1D -- floating parameter values in one-dimensional form
    data      -- flat sequence read as ``data[2*i]`` (t) and ``data[2*i+1]``
                 (y) for curve ``i``
    pfix      -- values of the fixed parameters

    Reads the module-level names ``pfloat_name_1D``, ``pfix_name`` and
    ``sprfunction``.

    Returns a 1-D array with ``model - y`` for every point of every curve,
    concatenated in curve order.  ``leastsq`` squares and sums the returned
    vector itself, so the residuals must not be squared here; squaring them
    (and adding the curves element-wise, as was done before) makes the
    optimiser minimise fourth powers and requires all curves to have the
    same length.
    """
    # The float parameters arrive in one dimension; regroup them per name.
    pfloat, pfloat_name = resizePfloat(pfloat_1D, pfloat_name_1D, 1)

    p_all, pname_all = pfloat + pfix, pfloat_name + pfix_name
    paratmp = createdict(p_all, pname_all)

    pieces = []
    for i in range(len(data) // 2):
        # List-valued parameters hold one value per curve; pick the i-th.
        paralist = [
            {'name': p['name'],
             'value': p['value'][i] if isinstance(p['value'], list) else p['value']}
            for p in paratmp
        ]
        t = data[i * 2]
        y = data[i * 2 + 1]
        model = sprfunction(t, paralist)
        pieces.append(np.ravel(np.asarray(model) - np.asarray(y)))

    if not pieces:
        # No curves at all: nothing to concatenate.
        return np.array([])
    return np.concatenate(pieces)
156    
157    
158    
def lmafit(data, pfix):
    """Run the least-squares fit and plot the result.

    data -- the data passed through to ``residuals`` and ``plotfit``
    pfix -- values of the fixed parameters

    The initial values of the floating parameters are taken from the
    module-level ``pfloat0_1D`` (already flattened to one dimension, as
    ``leastsq`` requires) and their names from ``pfloat_name_1D``.

    Prints the second item returned by ``leastsq``, calls ``plotfit`` and
    returns the fitted floating parameters in mixed-list form.
    """
    p_1D, success = leastsq(residuals, pfloat0_1D, args=(data, pfix), maxfev=10000)
    # Only the regrouped values are needed here; the names are discarded.
    p, _names = resizePfloat(p_1D, pfloat_name_1D, 1)
    print('Success: %s' % (success,))
    plotfit(data, p_1D, pfix)

    return p
172    
173    
def plotfit(data, pfit_1D, pfix):
    """Plot the real and the fitted data and print the fitted parameters.

    data    -- flat sequence read as ``data[2*i]`` (t) and ``data[2*i+1]``
               (y) for curve ``i``
    pfit_1D -- fitted floating parameter values in one-dimensional form
    pfix    -- values of the fixed parameters

    Reads the module-level names ``pfloat_name_1D``, ``pfix_name`` and
    ``sprfunction``; rebinds the module-level ``pfit_name``.

    The plot title is the sum of the squares of whatever ``residuals``
    returns for the fitted parameters.  Returns ``None``.
    """
    global pfit_name
    pfit, pfit_name = resizePfloat(pfit_1D, pfloat_name_1D, 1)
    txt = 'SumSqE: %1.8f' % sum(np.square(residuals(pfit_1D, data, pfix)))
    print(txt)
    plt.title(txt)
    plt.xlabel('Time (s)')
    plt.ylabel('Response (uRIU)')
    plt.grid(True)

    p_all, pname_all = pfit + pfix, pfit_name + pfix_name
    paratmp = createdict(p_all, pname_all)

    for i in range(len(data) // 2):
        # List-valued parameters hold one value per curve; pick the i-th.
        paralist = [
            {'name': p['name'],
             'value': p['value'][i] if isinstance(p['value'], list) else p['value']}
            for p in paratmp
        ]
        t = data[i * 2]
        y = data[i * 2 + 1]
        plt.plot(t, y, ',')                     # Plot real data
        plt.plot(t, sprfunction(t, paralist))   # Plot fitted data
    # NOTE(review): the original nesting of show() was lost in extraction;
    # it is placed after the loop so all curves share one figure - confirm.
    plt.show()

    print('\n' + '-' * 50)
    print('The fitted parameters are: %s' % (pfit_name,))
    print('The fitted value:')
    for name, value in zip(pfit_name, pfit):
        print('%s : %s' % (name, value))

    return
211    
212    
def updateparainfo(parainfo, pfit, pfit_name):
    """Return a copy of ``parainfo`` with the fitted values filled in.

    parainfo  -- list of parameter dicts (read through ``'name'`` and
                 ``'value'``); it is deep-copied and left unmodified
    pfit      -- fitted values
    pfit_name -- names matching ``pfit`` position by position

    For every entry whose name appears in ``pfit_name`` a line with the
    name, the initial value and the fitted value is printed, and the value
    in the copy is replaced.  Entries without a match are kept as they are.
    """
    new_parainfo = copy.deepcopy(parainfo)
    print('-' * 50)
    print('name\t initial value \t fitted value')
    for info in new_parainfo:
        for n, name in enumerate(pfit_name):
            if info['name'] == name:
                # One pre-formatted argument keeps the output text the same
                # on Python 2 and Python 3.
                print('%s \t%s \t%s' % (name, info['value'], pfit[n]))
                info['value'] = pfit[n]

    return new_parainfo
224    
225    
226    
def fitting(dataobj, mobj):
    """The main function of curvefitting.

    dataobj -- data object; its ``data`` attribute is fitted
    mobj    -- model object; it is deep-copied, and the copy's ``parainfo``,
               ``function`` and ``updatemodel`` are used

    Rebinds the module-level names ``pfloat0_1D``, ``pfix``,
    ``pfloat_name_1D``, ``pfix_name`` and ``sprfunction`` that the other
    functions of this module read.  After fitting, the user is asked on the
    console whether the copied model should be updated with the fitted
    values.

    Returns ``(pfit, modelobj)``: the fitted floating parameters and the
    (possibly updated) copy of the model.
    """
    global pfloat0_1D, pfix, pfloat_name_1D, pfix_name, sprfunction
    modelobj = copy.deepcopy(mobj)
    data = dataobj.data
    parainfo = modelobj.parainfo
    pfloat0_1D, pfix, pfloat_name_1D, pfix_name = initialize(data, parainfo)
    sprfunction = modelobj.function

    pfit = lmafit(data, pfix)
    # pfloat_name is the module-level list that initialize() rebinds.
    parainfo = updateparainfo(parainfo, pfit, pfloat_name)
    print('Do you want to update the model with fitted value?')
    # raw_input only exists on Python 2; fall back to input on Python 3.
    try:
        ask = raw_input
    except NameError:
        ask = input
    if ask('y/n : ').upper() == 'Y':
        modelobj.updatemodel(parainfo)
        print('The model is updated.')

    return pfit, modelobj