ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
Revision: 28
Committed: Wed May 19 06:06:07 2010 UTC (9 years, 3 months ago) by clausted
File size: 9619 byte(s)
Log Message:
Changes to params dictionary.  Valid values for 'fixed' are 'fixed', 'float', or an integer.  True and False are no longer acceptable.  Now multi-ROI curve fitting seems to work!
Line File contents
1 """
2 fit: Curve-fitting module for SPRI data.
3 Christopher Lausted, Institute for Systems Biology,
4 OSPRAI developers
5 Last modified on 100518 (yymmdd)
6
7 Example:
8 #import fit_module as fit
9 #import mdl_module as mdl
10 #ba1.roi[0].model = mdl.drift
11 #ba1.roi[0].params = {'rate': {'value':1, 'min':-100.0, 'max':100.0, 'fixed':'float'} }
12 #success = fit.lma(ba1.roi[0])
13 #for i in ba1.roi: fit.lma(i)
14 """
15 __version__ = "100518"
16
17
18 ## Import libraries
19 import ba_class as ba
20 from scipy.optimize import leastsq
21 from copy import deepcopy
22 from numpy import log, exp, tanh, arctanh
23 from numpy import sum, hstack, zeros
24
25
def lma(roi):
    """
    Unconstrained Levenberg-Marquardt fitting based on SciPy's leastsq().

    Takes a single RegionOfInterest-like object (roi) and adjusts, in place,
    the 'value' of every floating parameter in roi.params, i.e. every
    parameter whose 'fixed' entry is the string 'float'.  Parameters marked
    'fixed' (or tied to another ROI by an integer) keep their values.

    Returns the integer status flag reported by leastsq() (1, 2, 3 or 4
    means a solution was found), or 0 when there is no floating parameter
    to fit.
    """
    from numpy import atleast_1d

    params = roi.params
    checkparams(params)
    ## Initial values for the floating parameters, in dictionary order.
    p0 = [pd['value'] for pd in params.values() if pd['fixed'] == 'float']
    if not p0:
        ## Nothing to optimise; leastsq() rejects an empty initial guess.
        return 0
    ## Error function: model prediction minus measured data.
    erf = lambda p: (model_interface(roi, p) - roi.value)
    ## Fit.
    p1, success = leastsq(erf, p0, maxfev=500)
    ## The last point evaluated by leastsq() is not necessarily the solution
    ## it returns, so write the solution back into roi.params explicitly.
    model_interface(roi, atleast_1d(p1))
    return success
## End of lma() function.


def model_interface(roi, pfloat):
    """
    Use the model assigned to this roi to simulate data vs time.

    pfloat is a sequence holding one value per floating parameter (those
    with 'fixed' == 'float'), in the iteration order of roi.params.  The
    values are written into roi.params before the model is evaluated, which
    lets the LMA fitter adjust the parameters in place.

    Returns whatever roi.model(roi.time, roi.value, roi.params) returns.
    """
    ## pfloat only contains floating parameters, so it needs its own index
    ## rather than the position of the key among all parameters.
    i = 0
    for pd in roi.params.values():
        if pd['fixed'] == 'float':
            pd['value'] = pfloat[i]
            i += 1
    return roi.model(roi.time, roi.value, roi.params)
56
57
def checkparams(params):
    """
    Validate and repair a parameter dictionary in place.

    params maps parameter names to dictionaries.  Each of those is given the
    four keys 'value', 'min', 'max' and 'fixed' if they are missing
    (defaults: value 0, min and max equal to value, 'fixed':'fixed').

    For floating parameters ('fixed' == 'float') it also enforces
    min < value < max: a max that is not above min becomes min + 1.0, and a
    value outside the open interval is moved to the midpoint.  A single
    notice is printed if any such correction was made.  Returns None.
    """
    modified = False
    for pd in params.values():
        ## First make sure all four keys are present.
        pd.setdefault('value', 0)
        pd.setdefault('min', pd['value'])
        pd.setdefault('max', pd['value'])
        pd.setdefault('fixed', 'fixed')
        ## Range checks only apply to floating parameters.
        if pd['fixed'] != 'float':
            continue
        if pd['min'] >= pd['max']:
            ## Set max to be 1+min.
            pd['max'] = float(pd['min']) + 1.0
            modified = True
        if (pd['value'] <= pd['min']) or (pd['value'] >= pd['max']):
            ## Set initial value halfway between min and max.
            pd['value'] = 0.5 * (pd['min'] + pd['max'])
            modified = True
    if modified:
        print("Parameter min/max errors: Modifications were made automatically.")
    return
## End of checkparams() function.
85
86
87 #### The code below is a variant allowing parameter constraints ####
def transform(a, b, x0):
    """
    Map x0 from the unbounded range (-inf,+inf) into the interval (a,b).
    This keeps the estimates produced by the Levenberg-Marquardt algorithm
    inside their min/max limits.
    """
    ## Alternatives that were considered:
    ##   x^2 for (0,+inf)
    ##   a*tanh(x) for (-a,+a)
    ##   a+((b-a)/(1+exp(-x))) for (a,b), but it has problems
    ##   (exp() overflows for arguments above 709).
    ## Mapping in use: y = ((b-a)*tanh(x)+b+a)/2
    width = b - a
    numerator = width * tanh(x0) + b + a
    return numerator / 2
## End of transform() function.
102
103
def itransform(a, b, x1):
    """
    The (inverse) transformation to convert x1 from range (a,b) to (-inf,+inf).
    This is used to help constrain the outputs of the Levenberg-Marquardt
    algorithm.

    x1 must lie strictly between a and b; at either limit the argument of
    arctanh() is +/-1 and the result is infinite.
    """
    ## Inverse of the rejected logistic mapping, kept for reference:
    ##   x0 = -log( ((b-a)/(x1-a)) - 1 )
    ## The 2.0 forces true division even when a, b and x1 are all integers
    ## (Python 2 would otherwise truncate the ratio before arctanh()).
    x0 = arctanh((2.0 * x1 - b - a) / (b - a))
    return x0
## End of itransform() function.
113
114
def clma(roi):
    """
    Constrained Levenberg-Marquardt Algorithm fitting based on SciPy.

    Takes a single RegionOfInterest-like object (roi).  Every floating
    parameter in roi.params ('fixed' == 'float') is fitted within its
    (min,max) interval and its 'value' is updated in place.  Parameters
    marked 'fixed' (or tied to another ROI by an integer) are not changed.

    Returns the integer status flag reported by leastsq() (1, 2, 3 or 4
    means a solution was found), or 0 when there is no floating parameter
    to fit.
    """
    from numpy import atleast_1d

    params = roi.params
    checkparams(params)

    ## Initial guesses for the floating parameters, mapped from (min,max)
    ## onto the unbounded axis that leastsq() works on.
    p0 = []
    for key, pd in params.items():
        if pd['fixed'] == 'float':
            a, b = pd['min'], pd['max']
            p0.append(itransform(a, b, float(pd['value'])))
            ## Temporary: show initial guess and its transform.
            print("%s %s %s" % (key, pd['value'], p0[-1]))
    if not p0:
        ## Nothing to optimise; leastsq() rejects an empty initial guess.
        return 0

    ## Fit.  args must be a real tuple: (roi) without the comma is just roi.
    p1, success = leastsq(cerf, p0, args=(roi,), maxfev=999)
    ## The last point evaluated by leastsq() is not necessarily the solution
    ## it returns, so write the solution back into roi.params explicitly.
    cerf(atleast_1d(p1), roi)
    return success
## End of constrained lma() function.


def cerf(pfloat, roi):
    """
    Error estimation function for use with constrained LMA.

    pfloat holds one unbounded value per floating parameter of roi.params
    ('fixed' == 'float'), in dictionary iteration order.  Each is mapped
    back into the parameter's (min,max) interval and written to its 'value',
    which lets the LMA fitter adjust the parameters in place.

    Returns roi.model(roi.time, roi.value, roi.params) - roi.value.
    A one-line progress report is printed on every call (temporary).
    """
    params = roi.params
    report = []
    ## pfloat only contains floating parameters, so it needs its own index.
    i = 0
    for key, pd in params.items():
        if pd['fixed'] == 'float':
            a, b = pd['min'], pd['max']
            pd['value'] = transform(a, b, pfloat[i])
            i += 1
            report.append("%s %.2e" % (key, pd['value']))  # Temp.
    erf = (roi.model(roi.time, roi.value, params) - roi.value)
    ## sum is numpy's sum, imported at module level.
    report.append("sum %0.1f" % sum(erf))  # Temp.
    print(" ".join(report))
    return erf
## End of constrained lma erf() function.
159
160
def mclma(rois):
    """
    Multi-ROI constrained Levenberg-Marquardt fitting based on SciPy.

    Takes a sequence of RegionOfInterest-like objects (rois) and fits them
    together.  In each roi.params dictionary the 'fixed' entry selects how
    a parameter is treated:
      'fixed'    -- kept at its current value;
      'float'    -- fitted within its (min,max) interval;
      otherwise  -- shared: handled by mcerf(), which copies the value from
                    the same-named parameter of the ROI with that index.
    Parameter values are updated in place.

    Returns the integer status flag reported by leastsq() (1, 2, 3 or 4
    means a solution was found), or 0 when no ROI has a floating parameter.
    """
    from numpy import atleast_1d

    pval = []  ## List of values for floating parameters needed by leastsq().

    for roi in rois:
        ## Check that each parameter has all 4 keys: value, min, max, fixed.
        checkparams(roi.params)
        for pkey, pd in roi.params.items():
            if pd['fixed'] != 'float':
                ## Fixed or shared parameters contribute no unknown.
                continue
            ## Map the initial guess from (min,max) onto the unbounded axis.
            a, b = float(pd['min']), float(pd['max'])
            pval.append(itransform(a, b, float(pd['value'])))
            ## Print out initial guess and its transform.  Temporary.
            print("%s %s %s" % (pkey, pd['value'], pval[-1]))

    if not pval:
        ## Nothing to optimise; leastsq() rejects an empty initial guess.
        return 0

    ## Fit.  args must be a real tuple; a bare (rois) would be unpacked into
    ## separate arguments if rois itself happened to be a tuple.
    p1, success = leastsq(mcerf, pval, args=(rois,), maxfev=999)
    ## The last point evaluated by leastsq() is not necessarily the solution
    ## it returns, so write the solution back into the ROIs explicitly.
    mcerf(atleast_1d(p1), rois)
    return success
## End of multi-roi constrained lma() function.
193
194
def mcerf(pfloat, rois):
    """
    Error estimation function for use with multi-ROI constrained LMA.

    pfloat holds one unbounded value per floating parameter
    ('fixed' == 'float'), ordered ROI by ROI and, within an ROI, in the
    iteration order of roi.params.  Each is mapped back into the parameter's
    (min,max) interval and written to its 'value'.  Afterwards every shared
    parameter ('fixed' is neither 'fixed' nor 'float') copies its value from
    the same-named parameter of rois[int(fixed)].  This lets the LMA fitter
    adjust the parameters in place.

    Returns one array: the residuals (model minus data) of all ROIs
    concatenated in order.  A one-line progress report is printed on every
    call (temporary).
    """
    i = 0  ## Index for pfloat.
    report = []  ## Pieces of the temporary progress line.

    ## Insert the new guesses from pfloat into ROI parameters.
    ## First, get all of the floating parameters done.
    for j, roi in enumerate(rois):
        report.append("%s" % j)  # Temp
        for pkey, pd in roi.params.items():
            if pd['fixed'] == 'float':
                ## Floating parameter.  Transform back.
                a, b = float(pd['min']), float(pd['max'])
                pd['value'] = transform(a, b, pfloat[i])
                i += 1
                report.append("%s %.4f" % (pkey, pd['value']))  # Temp.

    ## Second, do the parameters shared across ROIs.  This must follow the
    ## first pass so the referenced floating values are already updated.
    for roi in rois:
        for pkey, pd in roi.params.items():
            if (pd['fixed'] == 'fixed') or (pd['fixed'] == 'float'):
                continue
            ## Shared--fixed to a floating parameter in another ROI.
            refroi = int(pd['fixed'])
            pd['value'] = rois[refroi].params[pkey]['value']

    ## Calculate model error function for each roi, then concatenate them
    ## in a single hstack() call.  The leading empty array keeps the result
    ## a float array even when rois is empty.
    residuals = [roi.model(roi.time, roi.value, roi.params) - roi.value
                 for roi in rois]
    erf = hstack([zeros(0)] + residuals)
    report.append("ssq %0.1f" % sum(erf**2))  # Temp.
    print(" ".join(report))

    return erf
## End of multi-roi constrained lma erf() function.
238
239
240 ################################# End of module #################################