1 |
"""
|
2 |
CurveFitting: Perform the curve fitting for the data object.
|
3 |
|
4 |
Yuhang Wan
|
5 |
Last modified on 100427 (yymmdd) by YW
|
6 |
|
7 |
Typical pipeline (works together with "converter.py" and a model class;
|
8 |
the basicmodel class is used as the example here):
|
9 |
>----------load data-----------
|
10 |
>import converter as cv
|
11 |
>dat = cv.datafile_read("dat_example.obj")
|
12 |
>----------load model----------
|
13 |
>import basicmodel as bm
|
14 |
>m = bm.basicmodel()
|
15 |
>m = m.load_model('model_example.obj')
|
16 |
>---------do the fitting-------
|
17 |
>import curvefitting as cf
|
18 |
>pfit, m2 = cf.fitting(dat, m)
|
19 |
|
20 |
"""
|
21 |
# Module version string; it repeats the yymmdd date of the last modification
# noted in the module docstring.
__version__ = '100427'
|
22 |
|
23 |
import numpy as np
|
24 |
import pylab as plt
|
25 |
import copy
|
26 |
from scipy.optimize import leastsq
|
27 |
import time
|
28 |
## import packageClamp_100225 as Pack
|
29 |
## import SPRdataclass_100225 as SPR
|
30 |
## import modelclass1_1 as Model
|
31 |
|
32 |
|
33 |
def _report_parameters(kind, value_label, names, values):
    '''Print one "name : value" table for the fixed or float parameters.'''
    print('The %s parameters are: %s' % (kind, names))
    print(value_label)
    for name, value in zip(names, values):
        print('%s : %s' % (name, value))


def _check_parameter_numbers(data, parainfo):
    '''Warn when a parameter's 'number' field does not fit the data.

    A parameter whose 'number' is not 1 is per-curve. Its 'number' must
    equal the number of curves in data, and its 'value' list must have that
    many entries. Only the first inconsistency is reported; nothing is
    raised.
    '''
    # data is laid out as [t1, y1, t2, y2, ...]: two entries per curve.
    n_curve = len(data) // 2
    for info in parainfo:
        number = info['number']
        if number == 1:
            continue
        if number != n_curve:
            print('the number of input curves doesn\'t match that of the '
                  'parameter\nPlease check the data and the parameters!')
            break
        if number != len(info['value']):
            print('Please check the infomation of the parameters!')
            break


def initialize(data, parainfo):
    '''Split the model parameters into fixed and float (to-be-fitted) sets.

    data     : sequence [t1, y1, t2, y2, ...] of the curves to fit.
    parainfo : list of parameter dicts with the keys
               'name', 'value', 'number' and 'fixed'.
               'fixed' == 1 marks a fixed parameter, 'fixed' == 0 a float
               one, and any other status is reported and skipped.
               'number' != 1 marks a per-curve parameter whose 'value' is a
               list.

    Prints both parameter tables and warns about inconsistencies between
    'number', the value lists and the number of curves.

    Returns (pfloat0_1D, pfix, pfloat_name_1D, pfix_name):
        the float initial values and their names flattened by
        resizePfloat(..., 0), plus the fixed values and names unchanged.

    Side effect: sets the module globals pfloat_name, pfloat0 and pfix_name.
    fitting() later reads the global pfloat_name.
    '''
    global pfloat_name, pfloat0, pfix_name
    # Separate the parameters to fit from those held fixed.
    pfloat_name, pfloat0, pfix_name, pfix = [], [], [], []
    for info in parainfo:
        if info['fixed'] == 1:  # fixed parameter
            pfix_name.append(info['name'])
            pfix.append(info['value'])
        elif info['fixed'] == 0:  # float parameter
            pfloat_name.append(info['name'])
            pfloat0.append(info['value'])
        else:
            print('unreadable status.\nplease check the parameter infomation!')

    print('--' * 30)
    _report_parameters('fixed', 'The fixed value:', pfix_name, pfix)
    print('\n' + '--' * 30)
    _report_parameters('float', 'The initial value:', pfloat_name, pfloat0)
    print('\n' + '--' * 30)

    # Format examination: check that the parameters match the data.
    _check_parameter_numbers(data, parainfo)

    # leastsq needs a flat vector, so list-valued parameters are expanded.
    pfloat0_1D, pfloat_name_1D = resizePfloat(pfloat0, pfloat_name, 0)
    return pfloat0_1D, pfix, pfloat_name_1D, pfix_name
|
81 |
|
82 |
def resizePfloat(p, pname, flag = 0):
    '''Convert the float-parameter list between its two layouts.

        1D:    [p_a, p_b, p_c1, p_c2, p_c3, p_d]
        mixed: [p_a, p_b, [p_c1, p_c2, p_c3], p_d]

    flag == 0: mixed -> 1D.
        Each element of a list-valued parameter becomes its own entry, and
        that parameter's name is repeated once per element.
    flag == 1: 1D -> mixed.
        A name occurring n > 1 times is collapsed into one list entry made of
        the n consecutive values starting at its first occurrence. Repeated
        names must therefore be contiguous, which is what flag == 0 produces.

    p     : parameter values (a list; for flag == 1 any sliceable sequence,
            e.g. the ndarray returned by leastsq).
    pname : parameter names aligned with p.

    Returns (new_p, new_pname).
    Raises ValueError if flag is neither 0 nor 1.
    '''
    # First examine the input (a warning only, kept for compatibility).
    if len(p) != len(pname):
        print('The parameter list doesn\'t match the name list, please check again!')

    new_p, new_pname = [], []

    if flag == 0:
        # From mixed list to one-dimensional list. Values and names are
        # paired by position. Looking the name up with p.index(value)
        # returned the wrong name whenever two parameters shared a value.
        for idx, value in enumerate(p):
            name = pname[idx]
            if isinstance(value, list):
                new_p.extend(value)
                new_pname.extend([name] * len(value))
            else:
                new_p.append(value)
                new_pname.append(name)

    elif flag == 1:
        # From one-dimensional list back to mixed list.
        counts = {}
        for name in pname:
            counts[name] = counts.get(name, 0) + 1
        seen = set()
        for idx, name in enumerate(pname):
            n = counts[name]
            if n == 1:
                new_p.append(p[idx])
                new_pname.append(name)
            elif name not in seen:
                seen.add(name)
                new_pname.append(name)
                new_p.append(list(p[idx:idx + n]))

    else:
        raise ValueError('flag must be 0 or 1, got %r' % (flag,))

    return new_p, new_pname
|
121 |
|
122 |
def createdict(p,pname):
    '''Pack the name and value of all the parameters into a dictionary list.

    p     : parameter values.
    pname : parameter names aligned with p.

    Returns [{'name': pname[0], 'value': p[0]}, ...] in input order.

    Raises ValueError if p and pname differ in length. The old code printed
    a warning and returned None, which only made the callers fail later with
    an unrelated TypeError.
    '''
    if len(p) != len(pname):
        raise ValueError('check the input.\nnames: %r\nvalues: %r' % (pname, p))
    return [{'name': name, 'value': value} for name, value in zip(pname, p)]
|
137 |
|
138 |
## def keepsinglevalue(paratmp, i):
|
139 |
## for n,p in enumerate(paratmp):
|
140 |
## if type(p['value'])==list:
|
141 |
## paralist[n]['value'] = p['value'][i]
|
142 |
## else:
|
143 |
## paralist[n]['value'] = p['value']
|
144 |
## return paralist
|
145 |
|
146 |
|
147 |
def residuals(pfloat_1D, data, pfix):
    '''Error function for scipy.optimize.leastsq.

    pfloat_1D : current float parameters in 1D format (supplied by leastsq).
    data      : sequence [t1, y1, t2, y2, ...] of curves.
    pfix      : values of the fixed parameters.

    Returns a 1D array holding, for every curve in turn, the point-wise
    difference sprfunction(t, params) - y.

    leastsq squares and sums this vector itself. The previous version
    returned squared errors summed across curves, so the minimised
    objective was effectively a fourth power. It also required every curve
    to have the same number of points.

    Reads the module globals pfloat_name_1D, pfix_name and sprfunction,
    which fitting()/initialize() set.
    '''
    # leastsq hands over a flat vector; restore the mixed (per-curve list)
    # layout so each parameter has one entry.
    pfloat, pfloat_name = resizePfloat(pfloat_1D, pfloat_name_1D, 1)

    paratmp = createdict(pfloat + pfix, pfloat_name + pfix_name)
    paralist = copy.deepcopy(paratmp)

    errors = []
    for i in range(len(data) // 2):
        # Per-curve (list-valued) parameters take their i-th entry;
        # shared parameters are passed through unchanged.
        for n, p in enumerate(paratmp):
            if isinstance(p['value'], list):
                paralist[n]['value'] = p['value'][i]
            else:
                paralist[n]['value'] = p['value']

        t = data[2 * i]
        y = data[2 * i + 1]
        errors.append(np.ravel(sprfunction(t, paralist) - y))

    if not errors:
        return np.zeros(0)
    return np.concatenate(errors)
|
168 |
|
169 |
|
170 |
|
171 |
def lmafit(data, pfix, maxfev=10000):
    '''Fit the float parameters with scipy.optimize.leastsq and plot the result.

    data   : sequence [t1, y1, t2, y2, ...] of curves.
    pfix   : values of the fixed parameters.
    maxfev : maximum number of function evaluations passed to leastsq
             (default 10000, the previously hard-coded value).

    Uses the module globals pfloat0_1D (1D initial values) and
    pfloat_name_1D (their names) set by fitting(), plus those residuals()
    reads.

    Returns the fitted float parameters in mixed format (list-valued
    per-curve parameters regrouped).
    '''
    # leastsq needs a one-dimensional start vector, hence pfloat0_1D, which
    # resizePfloat produced in initialize().
    p_1D, _cov, _info, mesg, ier = leastsq(residuals, pfloat0_1D,
                                           args=(data, pfix), maxfev=maxfev,
                                           full_output=True)
    p_1D = np.atleast_1d(p_1D)
    p, _ = resizePfloat(p_1D, pfloat_name_1D, 1)

    # leastsq reports an integer status code, not a boolean:
    # only 1, 2, 3 or 4 mean that a solution was found.
    success = ier in (1, 2, 3, 4)
    print('Success: %s' % success)
    print('Message: %s' % mesg)

    plotfit(data, p_1D, pfix)

    return p
|
184 |
|
185 |
|
186 |
def plotfit(data, pfit_1D, pfix):
    '''Plot the measured and fitted curves and print the fit result.

    data    : sequence [t1, y1, t2, y2, ...] of curves; curve i is
              (data[2i], data[2i+1]).
    pfit_1D : fitted float parameters in 1D format (as returned by leastsq).
    pfix    : values of the fixed parameters.

    Draws every measured curve (',' pixel markers) together with its fitted
    curve on one figure. The figure is titled with the sum of squared
    errors; it is then shown, and the fitted parameters are printed.

    Sets the module global pfit_name. Reads the globals pfloat_name_1D,
    pfix_name and sprfunction.
    '''
    global pfit_name
    pfit, pfit_name = resizePfloat(pfit_1D, pfloat_name_1D, 1)

    paratmp = createdict(pfit + pfix, pfit_name + pfix_name)
    paralist = copy.deepcopy(paratmp)

    # Evaluate every curve first. The title then shows the plain sum of
    # squared errors rather than whatever residuals() happens to return.
    curves = []
    sse = 0.0
    for i in range(len(data) // 2):
        # Per-curve (list-valued) parameters take their i-th entry.
        for n, p in enumerate(paratmp):
            if isinstance(p['value'], list):
                paralist[n]['value'] = p['value'][i]
            else:
                paralist[n]['value'] = p['value']
        t = data[2 * i]
        y = data[2 * i + 1]
        yfit = sprfunction(t, paralist)
        sse += float(np.sum(np.square(np.asarray(yfit) - np.asarray(y))))
        curves.append((t, y, yfit))

    txt = 'SumSqE: %1.8f' % sse
    print(txt)
    plt.title(txt)
    plt.xlabel('Time (s)')
    plt.ylabel('Response (uRIU)')
    plt.grid(True)
    for t, y, yfit in curves:
        plt.plot(t, y, ',')    # Plot real data
        plt.plot(t, yfit)      # Plot fitted data
    # Show once, after all curves are on the same figure.
    plt.show()

    print('\n' + '-' * 60)
    print('The fitted parameters are: %s' % (pfit_name,))
    print('The fitted value:')
    for name, value in zip(pfit_name, pfit):
        print('%s : %s' % (name, value))
|
222 |
|
223 |
|
224 |
def updateparainfo(parainfo, pfit, pfit_name):
    '''Return a deep copy of parainfo with the fitted values substituted.

    parainfo  : list of parameter dicts (each with at least 'name' and 'value').
    pfit      : fitted values, aligned with pfit_name.
    pfit_name : names of the fitted parameters.

    Entries whose name is not in pfit_name keep their value, and the input
    list is not modified. For every updated entry a comparison line
    (name, initial value, fitted value) is printed.
    '''
    new_parainfo = copy.deepcopy(parainfo)
    # Map name -> fitted value once, instead of scanning pfit_name for
    # every parameter.
    fitted = dict(zip(pfit_name, pfit))

    print('\n' + '+' * 25 + ' Comparison ' + '+' * 25)
    print('name\t\t initial value \t\t fitted value')
    for info in new_parainfo:
        name = info['name']
        if name in fitted:
            # Spacing reproduces the original Python 2 multi-item print.
            print('%s \t\t  %s  \t\t  %s' % (name, info['value'], fitted[name]))
            info['value'] = fitted[name]

    return new_parainfo
|
237 |
|
238 |
|
239 |
|
240 |
def fitting(dataobj, mobj, update=None):
    '''The main function of curvefitting.

    dataobj : data object whose .data is the curve sequence
              [t1, y1, t2, y2, ...].
    mobj    : model object providing .parainfo, .function (called as
              function(t, paralist)) and .updatemodel(parainfo).
              It is deep-copied, so mobj itself is never modified.
    update  : None (default) asks interactively whether the returned model
              should receive the fitted values. True or False skips the
              prompt, for non-interactive use.

    Returns (pfit, modelobj): the fitted float parameters in mixed format
    and the (possibly updated) copy of the model.

    Sets the module globals that initialize(), residuals(), lmafit() and
    plotfit() read.
    '''
    global pfloat0_1D, pfix, pfloat_name_1D, pfix_name, sprfunction
    modelobj = copy.deepcopy(mobj)
    time1 = time.time()

    data = dataobj.data
    parainfo = modelobj.parainfo
    pfloat0_1D, pfix, pfloat_name_1D, pfix_name = initialize(data, parainfo)
    sprfunction = modelobj.function

    pfit = lmafit(data, pfix)
    # pfloat_name is the module global that initialize() sets: the
    # (mixed-format) names of the float parameters, aligned with pfit.
    parainfo = updateparainfo(parainfo, pfit, pfloat_name)

    time2 = time.time()
    print('-' * 60)
    print("Time elapsed: %s seconds\n" % (time2 - time1))

    if update is None:
        print('Do you want to update the model with fitted value?')
        try:
            ask = raw_input  # Python 2
        except NameError:
            ask = input  # Python 3
        update = ask('y/n : ').strip().upper() == 'Y'

    if update:
        modelobj.updatemodel(parainfo)
        print('The model is updated.')

    return pfit, modelobj