ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
Revision: 22
Committed: Tue Apr 27 07:17:58 2010 UTC (9 years, 4 months ago) by clausted
File size: 8386 byte(s)
Log Message:
The new module osprai_one provides access to the important functions in all of the other modules.  Documentation has been started with osprai_one.html.
Line File contents
1 """
2 fit: Curve-fitting module for SPRI data.
3 Christopher Lausted, Institute for Systems Biology,
4 OSPRAI developers
5 Last modified on 100425 (yymmdd)
6
7 Example:
8 #import fit_module as fit
9 #import mdl_module as mdl
10 #ba1.roi[0].model = mdl.drift
11 #ba1.roi[0].params = {'rate': {'value':1, 'min':-100.0, 'max':100.0, 'fixed':False} }
12 #success = lma(ba1.roi[0])
13 #for i in ba1.roi: lma(i)
14 """
15 __version__ = "100425"
16
17
18 ## Import libraries
19 import ba_class as ba
20 import numpy as np
21 from scipy.optimize import leastsq
22 from copy import deepcopy
23 from numpy import log, exp, tanh, arctanh
24 #from math import exp
25
26
def lma(roi):
    """
    Normal (unconstrained) Levenberg-Marquardt fitting based on SciPy.
    This function takes a single ba_class RegionOfInterest object (roi).
    Every parameter whose 'fixed' entry is not True is floated; the others
    are held at their current 'value'.  min/max are only used by
    checkparams() to choose a valid starting value; the fit itself is
    unconstrained (use clma() for bounded fitting).
    It modifies roi.params in place, leaving each floating parameter at the
    best-fit value returned by leastsq.
    Returns the integer status flag (ier) from scipy.optimize.leastsq;
    values 1, 2, 3 or 4 mean a solution was found.
    """
    params = roi.params
    checkparams(params)
    ## Names of the floating parameters, in dictionary iteration order.
    ## model_interface() maps the optimizer's vector onto the floating
    ## parameters in this same order.
    floating = [key for key in params if params[key]['fixed'] != True]
    ## Initial values for the floating parameters.
    p0 = [params[key]['value'] for key in floating]

    def residual(p):
        ## Model minus measured data for the trial floating values p.
        return model_interface(roi, p) - roi.value

    ## Fit.
    p1, success = leastsq(residual, p0, maxfev=500)
    ## leastsq's last function evaluation need not be at the returned
    ## optimum (it may be a rejected trial step or a finite-difference
    ## probe), so store the best-fit values explicitly.
    for key, value in zip(floating, np.atleast_1d(p1)):
        params[key]['value'] = value
    return success
## End of lma() function.
44
45
def model_interface(roi, pfloat):
    """
    Use the model assigned to this roi to simulate data vs time.
    Provide a sequence of values for the floating parameters (pfloat): one
    entry per parameter whose 'fixed' entry is not True, in the iteration
    order of roi.params (the order lma() uses to build its initial guess).
    This function writes the values into the parameter dictionary
    (roi.params), which lets the LMA fitter adjust the parameters in place.
    Returns whatever roi.model(roi.time, roi.value, roi.params) returns.
    """
    params = roi.params
    ## pfloat has no entries for fixed parameters, so count only floating
    ## parameters instead of using each key's position in params.
    n = 0
    for key in params:
        if params[key]['fixed'] != True:
            params[key]['value'] = pfloat[n]
            n += 1
    return roi.model(roi.time, roi.value, params)
57
58
def checkparams(params):
    """
    Check that each dictionary in the params dictionary contains the four
    keys (value, min, max, fixed) and add any that are missing.
    Defaults: value=0, min=value, max=value, fixed=True.
    For every floating parameter -- one whose 'fixed' entry is not True,
    the same test lma() and clma() use to decide what to fit -- enforce
    min < value < max:
      - if min >= max, max is reset to min + 1.0;
      - if value is not strictly inside (min, max), it is reset to the
        midpoint of the interval.
    The dictionaries are modified in place, and a message is printed if any
    floating parameter had to be adjusted.  Returns None.
    """
    flag = False
    for p in params.values():
        ## First make sure all four keys are present.
        p.setdefault('value', 0)
        p.setdefault('min', p['value'])
        p.setdefault('max', p['value'])
        p.setdefault('fixed', True)
        ## Parameters the fitters hold constant need no bounds check.
        if p['fixed'] == True:
            continue
        if p['min'] >= p['max']:
            ## Set max to be 1+min.
            p['max'] = float(p['min']) + 1.0
            flag = True
        if p['value'] <= p['min'] or p['value'] >= p['max']:
            ## Set initial value halfway between min and max.
            p['value'] = 0.5 * (p['min'] + p['max'])
            flag = True
    if flag:
        print("Parameter min/max errors: Modifications were made automatically.")
    return
## End of checkparams() function.
86
87
#### Alternative strategy for constrained parameters: fit in an unbounded ####
#### space and map each floating value onto its (min, max) interval.     ####
def transform(a, b, x0):
    """
    Map an unbounded value x0 in (-inf, +inf) into the interval (a, b).

    Uses y = ((b - a) * tanh(x0) + b + a) / 2, so x0 = 0 lands on the
    midpoint (a + b) / 2 and large |x0| approach a or b (in floating point
    they can reach a or b exactly once tanh saturates).  This helps keep
    the parameters seen by the Levenberg-Marquardt fitter inside bounds.
    Earlier candidates were x**2 for (0, +inf), a*tanh(x) for (-a, +a), and
    the logistic a + (b - a) / (1 + exp(-x)), which had overflow problems
    in exp().
    """
    spread = (b - a) * tanh(x0)
    return (spread + b + a) / 2
## End of transform() function.
103
104
def itransform(a, b, x1):
    """
    The inverse of transform(): convert x1 from the interval (a, b) to
    (-inf, +inf) via x0 = arctanh((2*x1 - b - a) / (b - a)).
    This is used to help constrain the outputs of the Levenberg-Marquardt
    algorithm.
    Requires a != b.  x1 should lie strictly inside (a, b): numpy's arctanh
    gives +/-inf at the endpoints and nan outside them (with a
    RuntimeWarning).
    """
    ## 2.0 forces true (float) division even when a, b and x1 are all ints;
    ## under Python 2 integer division would otherwise truncate the ratio.
    x0 = arctanh((2.0 * x1 - b - a) / (b - a))
    return x0
## End of itransform() function.
114
115
def clma(roi):
    """
    Constrained Levenberg-Marquardt fitting based on SciPy.
    This function takes a single ba_class RegionOfInterest object (roi).
    Each floating parameter (one whose 'fixed' entry is not True) is fitted
    in an unbounded space and mapped into its (min, max) interval with
    transform(), so the values written to roi.params stay within bounds.
    It modifies roi.params in place, leaving each floating parameter at the
    best-fit value.
    Returns the integer status flag (ier) from scipy.optimize.leastsq;
    values 1, 2, 3 or 4 mean a solution was found.
    """
    params = roi.params
    ## Fill in missing keys and adjust out-of-range starting values;
    ## itransform() needs each start strictly inside (min, max).
    checkparams(params)

    ## Floating parameter names, in dictionary iteration order.  cerf()
    ## maps its vector entries onto the floating parameters in this order.
    floating = [key for key in params if params[key]['fixed'] != True]
    ## Initial guesses, converted into the unbounded fitting space using
    ## each parameter's own bounds.
    p0 = [itransform(params[key]['min'], params[key]['max'],
                     float(params[key]['value']))
          for key in floating]

    ## Fit.  args must be a tuple, hence (roi,) rather than (roi).
    p1, success = leastsq(cerf, p0, args=(roi,), maxfev=999)
    ## The last cerf() call need not be at the returned optimum, so map the
    ## best-fit vector back into bounded values explicitly.
    for key, x0 in zip(floating, np.atleast_1d(p1)):
        a, b = params[key]['min'], params[key]['max']
        params[key]['value'] = transform(a, b, x0)
    return success
## End of clma() function.
139
140
def cerf(pfloat, roi):
    """
    Error (residual) function for use with constrained LMA (clma).
    pfloat holds one unbounded value per floating parameter (one whose
    'fixed' entry is not True), in the iteration order of roi.params.
    Each value is mapped into its parameter's (min, max) interval with
    transform() and written into roi.params, which lets the fitter adjust
    the parameters in place.
    Returns the model assigned to this roi, simulated vs time, minus the
    measured data (roi.value).
    """
    params = roi.params
    ## pfloat has no entries for fixed parameters, so count floating
    ## parameters separately instead of using each key's position in params.
    n = 0
    for key in params:
        if params[key]['fixed'] != True:
            a, b = params[key]['min'], params[key]['max']
            params[key]['value'] = transform(a, b, pfloat[n])
            n += 1
    residual = roi.model(roi.time, roi.value, params) - roi.value
    return residual
## End of cerf() function.
160
161
## NOTE(review): The triple-quoted string below is disabled legacy code,
## kept only for reference.  It is a bare module-level string expression,
## so it is evaluated and discarded at import time; clma_old() and
## cerf_old() are NOT defined by this module.
## It implemented a penalty-based constrained fit: cerf_old() multiplied
## the residual by a factor that grows with each floating parameter's
## distance outside [min, max] (relative to the bound, or absolute when
## the bound is 0).
## Known issues if it is ever revived:
##  - cerf_old() indexes pfloat by each key's position among ALL params,
##    although p0 holds only the floating ones;
##  - args=(roi) is just roi, not a 1-tuple; write (roi,);
##  - for a negative bound, (a-x)/a or (x-b)/b is negative, so the
##    "penalty" factor can drop below 1 and reward leaving the bounds.
'''
#### The code below is an alternative strategy that is not very robust. ####
def clma_old(roi):
    """
    Constrained Levenberg-Marquart fitting based on SciPy.
    This function takes a single ba_class RegionOfInterest object (roi).
    It modifies roi.params in place. It returns the number of iterations completed.
    The error estimate is increased when a parameter moves outside the given bounds.
    """
    params = roi.params
    checkparams(params)
    ## Create list initial values for the floating parameters.
    p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
    ## Fit.
    p1, success = leastsq(cerf_old, p0, args=(roi), maxfev=10000)
    return success
## End of clma() function.


def cerf_old(pfloat, roi):
    """
    Constraining Error Function.
    Use the model assigned to this roi to simulation data vs time and calculate the error.
    This allows the Constrained LMA fitter to adjust the parameters in place.
    Provide a list of values for the floating parameters (pfloat).
    The penalty factor is 1*(distance outside bounds).
    Versus regular LMA, it's less robust--initialize with reasonable parameter guesses.
    """
    penalty = 1
    for i,key in enumerate(roi.params):
        if (roi.params[key]['fixed'] != True):
            ## Substitute floating parameters in roi.params.
            roi.params[key]['value'] = pfloat[i]
            ## Impose penalties when outside constraints.
            a = roi.params[key]['min']
            b = roi.params[key]['max']
            x = roi.params[key]['value']
            if (x<a and a==0): penalty = penalty * (1 + (a-x))
            if (x<a and a!=0): penalty = penalty * (1 + (a-x)/a)
            if (x>b and b==0): penalty = penalty * (1 + (x-b))
            if (x>b and b!=0): penalty = penalty * (1 + (x-b)/b)
            print "%s %f" % (key, x),
    print "penalty %0.2f" % (penalty)
    errval = roi.model(roi.time, roi.value, roi.params) - roi.value
    errval = errval * penalty
    return errval
## End of cerf() function.
'''
210
211 ################################# End of module #################################