ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
Revision: 20
Committed: Mon Apr 26 21:59:45 2010 UTC (9 years, 4 months ago) by clausted
File size: 8431 byte(s)
Log Message:
Added clma(), a constrained LMA fitting function.  We can now set a lower and upper bound for fitting.  Be sure that the initial estimate is between, not at, the bounds.  It transforms the parameters using x1 = ((b-a)*tanh(x0)+b+a)/2 for the interval (a,b).  
Line File contents
"""
fit: Curve-fitting module for SPRI data.
Christopher Lausted, Institute for Systems Biology,
OSPRAI developers
Last modified on 100425 (yymmdd)

Fits the model assigned to a ba_class RegionOfInterest (roi.model) to its
data (roi.value) with SciPy's Levenberg-Marquardt least-squares solver:
  lma(roi)  -- unconstrained fit of the floating parameters.
  clma(roi) -- fit that keeps each floating parameter inside (min, max).

Each roi.params entry is a dict with keys 'value', 'min', 'max', 'fixed'.

Example:
#import fit_module as fit
#import mdl_module as mdl
#ba1.roi[0].model = mdl.drift
#ba1.roi[0].params = {'rate': {'value':1, 'min':-100.0, 'max':100.0, 'fixed':False} }
#success = fit.lma(ba1.roi[0])
#for i in ba1.roi: fit.lma(i)
"""
__version__ = "100425"
16
17
18 ## Import libraries
19 import ba_class as ba
20 import numpy as np
21 from scipy.optimize import leastsq
22 from copy import deepcopy
23 from numpy import log, exp, tanh, arctanh
24 #from math import exp
25
26
def lma(roi):
    """
    Unconstrained Levenberg-Marquardt fitting based on SciPy.

    Fits roi.model to roi.value by adjusting every parameter in roi.params
    whose 'fixed' entry is not True.  The 'min'/'max' bounds are only used by
    checkparams() to sanity-check the starting value; they are NOT enforced
    during the fit (use clma() for that).

    roi.params is modified in place: on return, each floating parameter's
    'value' holds the solution returned by the optimizer.

    Parameters
    ----------
    roi : ba_class RegionOfInterest
        Object providing .params, .model, .time and .value.

    Returns
    -------
    int
        The status flag (ier) from scipy.optimize.leastsq.  1, 2, 3 or 4
        means a solution was found; any other value indicates failure.
    """
    params = roi.params
    checkparams(params)
    ## Names of the floating parameters, in dict iteration order.  This is
    ## the same order model_interface() uses when it unpacks its pfloat list.
    floating = [key for key in params if params[key]['fixed'] != True]
    ## Initial guesses for the floating parameters.
    p0 = [params[key]['value'] for key in floating]
    ## Residuals between the model evaluated at p and the measured data.
    residuals = lambda p: (model_interface(roi, p) - roi.value)
    ## Fit.
    p1, success = leastsq(residuals, p0, maxfev=500)
    ## leastsq's final call to residuals() may have been a rejected trial
    ## step, so store the returned solution explicitly.
    for key, val in zip(floating, np.atleast_1d(p1)):
        params[key]['value'] = val
    return success
## End of lma() function.
44
45
def model_interface(roi, pfloat):
    """
    Evaluate the roi's model using new values for its floating parameters.

    pfloat holds one value per floating parameter (those whose 'fixed' entry
    is not True), in roi.params iteration order.  lma() builds its initial
    guess in this same order.  Each value is written to
    roi.params[key]['value'], which lets the LMA fitter adjust the parameters
    in place.  Fixed parameters are left untouched.

    Returns roi.model(roi.time, roi.value, roi.params), i.e. the simulated
    data vs time.
    """
    ## pfloat contains only the floating parameters, so it needs its own
    ## index rather than the position of the key among ALL parameters.
    j = 0
    for key in roi.params:
        if roi.params[key]['fixed'] != True:
            roi.params[key]['value'] = pfloat[j]
            j += 1
    return roi.model(roi.time, roi.value, roi.params)
57
58
def checkparams(params):
    """
    Validate and repair a parameter dictionary in place.

    Each entry of params (name -> dict) is given any missing keys.  'value'
    defaults to 0, 'min' and 'max' default to 'value', and 'fixed' defaults
    to True.

    Floating parameters are those whose 'fixed' entry is not True, the same
    test lma(), clma() and model_interface() use.  For them it then enforces
    min < value < max:
      * if min >= max, max is reset to min + 1.0;
      * if value is not strictly inside (min, max), it is moved to the midpoint.
    A notice is printed when any repair of this kind is made.  Returns None.
    """
    flag = False
    for entry in params.values():
        ## First make sure all four keys are present.
        if 'value' not in entry:
            entry['value'] = 0
        if 'min' not in entry:
            entry['min'] = entry['value']
        if 'max' not in entry:
            entry['max'] = entry['value']
        if 'fixed' not in entry:
            entry['fixed'] = True
        ## Floating parameters need min < value < max.  Use "!= True" so
        ## this agrees with the fitters about which parameters float.
        if entry['fixed'] != True:
            if entry['min'] >= entry['max']:
                ## Give the interval unit width above min.
                entry['max'] = float(entry['min']) + 1.0
                flag = True
            if entry['value'] <= entry['min'] or entry['value'] >= entry['max']:
                ## Start halfway between the bounds.
                entry['value'] = 0.5 * (entry['min'] + entry['max'])
                flag = True
    if flag:
        print("Parameter min/max errors: Modifications were made automatically.")
    return
## End of checkparams() function.
86
87
88 #### The code below is an alternative, non-working strategy. ####
def transform(a, b, x0):
    """
    Map x0 from (-inf, +inf) onto the open interval (a, b).

    The constrained Levenberg-Marquardt fit lets the optimizer move x0 freely
    while the corresponding parameter value always stays inside its bounds.
    The mapping is ((b-a)*tanh(x0) + b + a) / 2: it tends to a as
    x0 -> -inf, equals (a+b)/2 at x0 = 0, and tends to b as x0 -> +inf.
    itransform() is its inverse.

    Alternatives considered: x**2 for (0, +inf), a*tanh(x) for (-a, +a), and
    the logistic a + (b-a)/(1+exp(-x)), which ran into exp() overflow.
    """
    width = b - a
    return (width * tanh(x0) + b + a) / 2
## End of transform() function.
103
104
def itransform(a, b, x1):
    """
    Inverse of transform(): map x1 from the open interval (a, b) to (-inf, +inf).
    This is used to help constrain the outputs of the Levenberg-Marquardt algorithm.

    Computes x0 = arctanh((2*x1 - b - a) / (b - a)).  x1 should lie strictly
    between a and b.  With numpy's arctanh, a value exactly at a bound gives
    -inf/+inf and a value outside the interval gives nan.
    """
    ## 2.0 (not 2) forces true division even when a, b and x1 are all ints
    ## under Python 2's classic division, which would otherwise floor-divide.
    x0 = arctanh((2.0 * x1 - b - a) / (b - a))
    return x0
## End of itransform() function.
114
115
116 ## Warning: The next two functions are not yet working.
def clma(roi):
    """
    Constrained Levenberg-Marquardt fitting based on SciPy.

    Like lma(), but each floating parameter (any whose 'fixed' entry is not
    True) is kept strictly inside its (min, max) interval.  The optimizer
    works on unbounded variables:
      * initial values are mapped into that space with itransform();
      * during the fit, cerf() maps them back onto the bounds with transform().
    The starting 'value' must lie strictly between 'min' and 'max';
    checkparams() moves it to the midpoint otherwise.

    roi.params is modified in place: on return each floating parameter's
    'value' holds the bounded solution.

    Returns the integer status flag (ier) from scipy.optimize.leastsq;
    1, 2, 3 or 4 means a solution was found, other values indicate failure.
    """
    params = roi.params
    checkparams(params)

    ## Names of the floating parameters, in dict iteration order.  cerf()
    ## walks roi.params in this same order when it unpacks pfloat.
    floating = [key for key in params if params[key]['fixed'] != True]
    ## Map each initial value from (min, max) to the unbounded variable.
    p0 = [itransform(params[key]['min'], params[key]['max'],
                     float(params[key]['value']))
          for key in floating]

    ## Fit.  args is a one-element tuple holding the roi.
    p1, success = leastsq(cerf, p0, args=(roi,), maxfev=999)

    ## Map the solution back into bounds and store it.  The last cerf() call
    ## may have been a rejected trial step rather than p1.
    for key, x in zip(floating, np.atleast_1d(p1)):
        params[key]['value'] = transform(params[key]['min'], params[key]['max'], x)
    return success
## End of clma() function.
140
141
def cerf(pfloat, roi):
    """
    Residual (error) function for the constrained LMA fit in clma().

    pfloat holds one unbounded value per floating parameter (those whose
    'fixed' entry is not True), in roi.params iteration order.  For each one:
      * it is mapped onto its parameter's (min, max) interval with transform();
      * the result is written to roi.params[key]['value'], so the fitter
        adjusts the parameters in place.
    Returns the residuals roi.model(roi.time, roi.value, roi.params) - roi.value.
    """
    params = roi.params
    ## pfloat contains only the floating parameters, so it needs its own
    ## index rather than the position of the key among ALL parameters.
    j = 0
    for key in params:
        if params[key]['fixed'] != True:
            a, b = params[key]['min'], params[key]['max']
            params[key]['value'] = transform(a, b, pfloat[j])
            j += 1
    return roi.model(roi.time, roi.value, params) - roi.value
## End of cerf() function.
161
162
## Retired constrained-fit strategy (clma_old / cerf_old), kept for reference only.
## It is wrapped in a module-level string literal, so none of it is executed or
## importable.  Instead of mapping parameters onto their bounds, it multiplied
## the residuals by a penalty factor whenever a parameter left (min, max).
## NOTE(review), observations from reading the text below, not from running it:
##   * When a bound is negative, (a-x)/a or (x-b)/b is negative, so the
##     "penalty" factor shrinks below 1 instead of growing.
##   * pfloat[i] is indexed by position among ALL params, which mismatches p0
##     (floating params only) whenever a fixed parameter comes first.
'''
#### The code below is an alternative strategy that is not very robust. ####
def clma_old(roi):
    """
    Constrained Levenberg-Marquart fitting based on SciPy.
    This function takes a single ba_class RegionOfInterest object (roi).
    It modifies roi.params in place. It returns the number of iterations completed.
    The error estimate is increased when a parameter moves outside the given bounds.
    """
    params = roi.params
    checkparams(params)
    ## Create list initial values for the floating parameters.
    p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
    ## Fit.
    p1, success = leastsq(cerf_old, p0, args=(roi), maxfev=10000)
    return success
## End of clma() function.


def cerf_old(pfloat, roi):
    """
    Constraining Error Function.
    Use the model assigned to this roi to simulation data vs time and calculate the error.
    This allows the Constrained LMA fitter to adjust the parameters in place.
    Provide a list of values for the floating parameters (pfloat).
    The penalty factor is 1*(distance outside bounds).
    Versus regular LMA, it's less robust--initialize with reasonable parameter guesses.
    """
    penalty = 1
    for i,key in enumerate(roi.params):
        if (roi.params[key]['fixed'] != True):
            ## Substitute floating parameters in roi.params.
            roi.params[key]['value'] = pfloat[i]
            ## Impose penalties when outside constraints.
            a = roi.params[key]['min']
            b = roi.params[key]['max']
            x = roi.params[key]['value']
            if (x<a and a==0): penalty = penalty * (1 + (a-x))
            if (x<a and a!=0): penalty = penalty * (1 + (a-x)/a)
            if (x>b and b==0): penalty = penalty * (1 + (x-b))
            if (x>b and b!=0): penalty = penalty * (1 + (x-b)/b)
            print "%s %f" % (key, x),
    print "penalty %0.2f" % (penalty)
    errval = roi.model(roi.time, roi.value, roi.params) - roi.value
    errval = errval * penalty
    return errval
## End of cerf() function.
'''
211
212 ################################# End of module #################################