ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
Revision: 28
Committed: Wed May 19 06:06:07 2010 UTC (9 years, 5 months ago) by clausted
File size: 9619 byte(s)
Log Message:
Changes to params dictionary.  Valid values for 'fixed' are 'fixed', 'float', or an integer.  True and False are no longer acceptable.  Now multi-ROI curve fitting seems to work!
Line User Rev File contents
1 clausted 19 """
2     fit: Curve-fitting module for SPRI data.
3     Christopher Lausted, Institute for Systems Biology,
4     OSPRAI developers
5 clausted 27 Last modified on 100518 (yymmdd)
6 clausted 19
7     Example:
8     #import fit_module as fit
9     #import mdl_module as mdl
10     #ba1.roi[0].model = mdl.drift
11 clausted 28 #ba1.roi[0].params = {'rate': {'value':1, 'min':-100.0, 'max':100.0, 'fixed':'float'} }
12 clausted 19 #success = lma(ba1.roi[0])
13     #for i in ba1.roi: lma(i)
14     """
15 clausted 27 __version__ = "100518"
16 clausted 19
17    
18     ## Import libraries
19     import ba_class as ba
20     from scipy.optimize import leastsq
21     from copy import deepcopy
22 clausted 20 from numpy import log, exp, tanh, arctanh
23 clausted 27 from numpy import sum, hstack, zeros
24 clausted 19
25    
def lma(roi):
    """
    Unconstrained Levenberg-Marquardt fitting of a single ROI based on SciPy.

    This function takes a single ba_class RegionOfInterest object (roi).
    Parameters in roi.params whose 'fixed' entry is 'float' are fitted;
    their 'min'/'max' entries are not enforced during the fit (see clma for
    a constrained fit).  All other parameters keep their current value.

    roi.params is modified in place.  On return, the floating parameters
    hold the best-fit values found by leastsq.

    Returns the integer status flag from scipy.optimize.leastsq (1, 2, 3 or
    4 means a solution was found).  It is not an iteration count.
    """
    params = roi.params
    checkparams(params)
    ## Names of the floating parameters; this fixes the order of the LMA vector.
    fkeys = [key for key, pd in params.items() if pd['fixed'] == 'float']
    ## Initial values for the floating parameters, in the same order.
    p0 = [params[key]['value'] for key in fkeys]

    def _store(pfloat):
        ## Write a trial vector back into the parameter dictionaries.
        for key, val in zip(fkeys, pfloat):
            params[key]['value'] = val

    def _residuals(pfloat):
        ## Error function: model minus data for the trial vector.
        _store(pfloat)
        return roi.model(roi.time, roi.value, params) - roi.value

    p1, success = leastsq(_residuals, p0, maxfev=500)
    ## The last evaluation inside leastsq is not guaranteed to be at p1,
    ## so store the returned solution explicitly.
    _store(p1)
    return success
## End of lma() function.
43    
44    
def model_interface(roi, pfloat):
    """
    Use the model assigned to this roi to simulate data vs time.

    pfloat holds one value per parameter of roi.params whose 'fixed' entry
    is 'float', in the order the parameters are iterated.  Each value is
    written to roi.params[key]['value'], which lets a fitter adjust the
    parameters in place.  Fixed ('fixed') and shared (integer) parameters
    keep their current value.

    Returns roi.model(roi.time, roi.value, roi.params).
    Raises IndexError if pfloat has fewer entries than there are floating
    parameters.
    """
    i = 0  ## Index into pfloat; advances only on floating parameters.
    for pd in roi.params.values():
        if pd['fixed'] == 'float':
            pd['value'] = pfloat[i]
            i += 1
    return roi.model(roi.time, roi.value, roi.params)
56 clausted 19
57    
def checkparams(params):
    """
    Validate and complete the parameter dictionaries of one ROI, in place.

    params maps each parameter name to a dictionary with the four keys
    'value', 'min', 'max' and 'fixed'.  Missing keys are added: 'value'
    defaults to 0, 'min' and 'max' default to 'value', and 'fixed' defaults
    to 'fixed'.

    'fixed' is 'fixed' (hold constant), 'float' (fit), or an integer index
    of another ROI whose parameter of the same name is shared.  Booleans are
    rejected.  The multi-ROI fitter converts any other 'fixed' entry with
    int(), so True/False would silently become ROI index 1/0.

    For floating parameters, min < value < max is enforced.  If
    min >= max, max is set to min + 1.0.  A value outside the open interval
    (min, max) is replaced by the midpoint.  One message is printed if any
    such adjustment was made.

    Raises ValueError if a 'fixed' entry is True or False.
    """
    adjusted = False
    for key, pd in params.items():
        ## First make sure all four keys are present.
        pd.setdefault('value', 0)
        pd.setdefault('min', pd['value'])
        pd.setdefault('max', pd['value'])
        pd.setdefault('fixed', 'fixed')
        if isinstance(pd['fixed'], bool):
            raise ValueError("Parameter '%s': 'fixed' must be 'fixed', 'float' "
                             "or an ROI index, not %r." % (key, pd['fixed']))
        if pd['fixed'] != 'float':
            continue
        ## Floating parameter: enforce min < value < max.
        if pd['min'] >= pd['max']:
            ## Set max to be 1+min.
            pd['max'] = float(pd['min']) + 1.0
            adjusted = True
        if (pd['value'] <= pd['min']) or (pd['value'] >= pd['max']):
            ## Start halfway between min and max; itransform() is infinite at the bounds.
            pd['value'] = 0.5 * (pd['min'] + pd['max'])
            adjusted = True
    if adjusted:
        print("Parameter min/max errors: Modifications were made automatically.")
## End of checkparams() function.
85    
86    
87 clausted 28 #### The code below is a variant allowing parameter constraints ####
def transform(a, b, x0):
    """
    Map an unbounded value x0 in (-inf, +inf) onto the interval from a to b.

    Uses y = ((b - a) * tanh(x0) + b + a) / 2.  x0 = 0 lands on the midpoint
    of (a, b), and large negative / positive x0 approach a / b.  The
    constrained fitters use this so the Levenberg-Marquardt search variable
    is unconstrained while the parameter it represents stays within its
    min/max range.  itransform() is the inverse mapping.
    """
    width = b - a
    return (width * tanh(x0) + b + a) / 2
## End of transform() function.
102    
103    
def itransform(a, b, x1):
    """
    Inverse of transform(): map x1 from the interval (a, b) back to (-inf, +inf).

    Computes arctanh((2*x1 - b - a) / (b - a)).  The constrained fitters use
    it to convert initial parameter guesses into the unbounded variable
    searched by the Levenberg-Marquardt algorithm.

    x1 should lie strictly inside (a, b), and a must differ from b.  At the
    endpoints the result is +/-inf, and a == b divides by zero.
    checkparams() enforces min < value < max for floating parameters.
    """
    ## 2.0 (not 2) forces true division under Python 2 when a, b and x1 are
    ## all ints; floor division would turn e.g. -4/10 into -1 and give -inf.
    x0 = arctanh((2.0 * x1 - b - a) / (b - a))
    return x0
## End of itransform() function.
113    
114    
def clma(roi):
    """
    Constrained Levenberg-Marquardt fitting of a single ROI based on SciPy.

    This function takes a single ba_class RegionOfInterest object (roi).
    Parameters in roi.params whose 'fixed' entry is 'float' are fitted inside
    their (min, max) range.  The optimizer works on an unbounded variable
    that transform() maps into (min, max), so a fitted value cannot leave
    that range.  All other parameters keep their current value; shared
    (integer 'fixed') parameters are only resolved by mclma().

    roi.params is modified in place.  On return, the floating parameters
    hold the best-fit values found by leastsq.

    Returns the integer status flag from scipy.optimize.leastsq (1, 2, 3 or
    4 means a solution was found).  It is not an iteration count.
    """
    params = roi.params
    checkparams(params)

    ## Floating parameter names and bounds; this fixes the order of the LMA vector.
    fkeys = [key for key, pd in params.items() if pd['fixed'] == 'float']
    bounds = [(float(params[key]['min']), float(params[key]['max'])) for key in fkeys]
    ## Initial guesses, mapped from (min, max) onto (-inf, +inf).
    p0 = [itransform(a, b, float(params[key]['value']))
          for key, (a, b) in zip(fkeys, bounds)]

    def _store(pfloat):
        ## Map unbounded LMA values back into (min, max) and save them.
        for key, (a, b), x in zip(fkeys, bounds, pfloat):
            params[key]['value'] = transform(a, b, x)

    def _residuals(pfloat):
        ## Error function: model minus data for the trial vector.
        _store(pfloat)
        return roi.model(roi.time, roi.value, params) - roi.value

    p1, success = leastsq(_residuals, p0, maxfev=999)
    ## The last evaluation inside leastsq is not guaranteed to be at p1,
    ## so store the returned solution explicitly.
    _store(p1)
    return success
## End of constrained lma() function.
138    
139    
def cerf(pfloat, roi):
    """
    Error estimation function for constrained LMA fitting of a single ROI.

    pfloat holds one unbounded value per parameter of roi.params whose
    'fixed' entry is 'float', in the order the parameters are iterated.
    Each value is mapped into that parameter's (min, max) range with
    transform() and written to roi.params[key]['value'].  This lets the
    fitter adjust the parameters in place.  Other parameters are untouched.

    Returns roi.model(roi.time, roi.value, roi.params) - roi.value, i.e. the
    simulated data vs time minus the measured data.
    """
    params = roi.params
    i = 0  ## Index into pfloat; advances only on floating parameters.
    for pd in params.values():
        if pd['fixed'] == 'float':
            a, b = float(pd['min']), float(pd['max'])
            pd['value'] = transform(a, b, pfloat[i])
            i += 1
    return roi.model(roi.time, roi.value, params) - roi.value
## End of constrained lma erf() function.
159    
160 clausted 27
def mclma(rois):
    """
    Constrained multi-ROI Levenberg-Marquardt fitting based on SciPy.

    rois is a sequence of ba_class RegionOfInterest objects, fitted
    simultaneously.  In each roi.params, the 'fixed' entry selects:
      'fixed'  - the parameter keeps its current value;
      'float'  - the parameter is fitted inside its (min, max) range via
                 itransform()/transform();
      integer k - the parameter takes the value of the parameter with the
                 same name in rois[k] (shared across ROIs, see mcerf()).

    Every roi.params is modified in place.  On return, the parameters hold
    the best-fit values found by leastsq.

    Returns the integer status flag from scipy.optimize.leastsq (1, 2, 3 or
    4 means a solution was found).  It is not an iteration count.
    """
    pval = []  ## Unbounded values of all floating parameters, in mcerf() order.
    for roi in rois:
        ## Make sure every parameter has value, min, max and fixed.
        checkparams(roi.params)
        for pd in roi.params.values():
            if pd['fixed'] == 'float':
                ## Map the initial guess from (min, max) onto (-inf, +inf).
                a, b = float(pd['min']), float(pd['max'])
                pval.append(itransform(a, b, float(pd['value'])))
            ## 'fixed' and shared parameters contribute nothing to pval.

    ## args must be a 1-tuple: a bare tuple of ROIs would otherwise be
    ## unpacked by leastsq into separate arguments.
    p1, success = leastsq(mcerf, pval, args=(rois,), maxfev=999)
    ## The last evaluation inside leastsq is not guaranteed to be at p1, so
    ## re-evaluate once to leave the best-fit (and shared) values in place.
    mcerf(p1, rois)
    return success
## End of multi-roi constrained lma() function.
193 clausted 20
194    
def mcerf(pfloat, rois):
    """
    Error estimation function for constrained multi-ROI LMA (used by mclma).

    pfloat holds one unbounded value per floating parameter ('fixed' ==
    'float').  The values are ordered by ROI, then by parameter iteration
    order within each roi.params, which is the order mclma() uses to build
    its initial vector.

    1. Each floating value is mapped into its (min, max) range with
       transform() and written to roi.params[key]['value'].
    2. Each shared parameter (integer 'fixed' entry k) copies the value of
       the parameter with the same name in rois[k].
    3. The residuals roi.model(roi.time, roi.value, roi.params) - roi.value
       of all ROIs are concatenated.

    Returns the concatenated residuals as a 1-D numpy array (empty if rois
    is empty).
    """
    ## First, insert the new guesses from pfloat into the floating parameters.
    i = 0  ## Index into pfloat.
    for roi in rois:
        for pd in roi.params.values():
            if pd['fixed'] == 'float':
                a, b = float(pd['min']), float(pd['max'])
                pd['value'] = transform(a, b, pfloat[i])
                i += 1

    ## Second, copy values into parameters shared with another ROI.
    for roi in rois:
        for pkey, pd in roi.params.items():
            if (pd['fixed'] != 'fixed') and (pd['fixed'] != 'float'):
                refroi = int(pd['fixed'])
                pd['value'] = rois[refroi].params[pkey]['value']

    ## Model error for each roi, concatenated once.  The leading zeros(0)
    ## keeps the float dtype and handles an empty rois sequence.
    parts = [zeros(0)]
    for roi in rois:
        parts.append(roi.model(roi.time, roi.value, roi.params) - roi.value)
    return hstack(parts)
## End of multi-roi constrained lma erf() function.
238    
239    
240 clausted 19 ################################# End of module #################################