ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
Revision: 20
Committed: Mon Apr 26 21:59:45 2010 UTC (9 years, 6 months ago) by clausted
File size: 8431 byte(s)
Log Message:
Added clma(), a constrained LMA fitting function.  We can now set a lower and upper bound for fitting.  Be sure that the initial estimate is between, not at, the bounds.  It transforms the parameters using x1 = ((b-a)*tanh(x0)+b+a)/2 for the interval (a,b).  
Line User Rev File contents
1 clausted 19 """
2     fit: Curve-fitting module for SPRI data.
3     Christopher Lausted, Institute for Systems Biology,
4     OSPRAI developers
5 clausted 20 Last modified on 100425 (yymmdd)
6 clausted 19
7     Example:
8     #import fit_module as fit
9     #import mdl_module as mdl
10     #ba1.roi[0].model = mdl.drift
11     #ba1.roi[0].params = {'rate': {'value':1, 'min':-100.0, 'max':100.0, 'fixed':False} }
12     #success = fit.lma(ba1.roi[0])
13     #for i in ba1.roi: fit.lma(i)
14     """
15 clausted 20 __version__ = "100425"
16 clausted 19
17    
18     ## Import libraries
19     import ba_class as ba
20     import numpy as np
21     from scipy.optimize import leastsq
22     from copy import deepcopy
23 clausted 20 from numpy import log, exp, tanh, arctanh
24     #from math import exp
25 clausted 19
26    
def lma(roi):
    """
    Unconstrained Levenberg-Marquardt fit based on scipy.optimize.leastsq.

    Takes a single region-of-interest object (roi) and fits every parameter
    in roi.params whose 'fixed' entry is not True.  The residual that is
    minimised is model_interface(roi, p) - roi.value.

    roi.params is modified in place: it is first normalised by
    checkparams(), and after the fit the floating parameters hold the
    solution returned by leastsq.

    Returns the integer status flag reported by leastsq (per the SciPy
    documentation, 1-4 mean a solution was found).  This is NOT an
    iteration count.
    """
    params = roi.params
    checkparams(params)
    ## Keys of the floating parameters, in dictionary iteration order.
    floating = [key for key in params if params[key]['fixed'] != True]
    ## Initial values for the floating parameters, in the same order.
    p0 = [params[key]['value'] for key in floating]

    def residuals(p):
        ## Error function: simulated data minus measured data.
        return model_interface(roi, p) - roi.value

    ## Fit.
    p1, success = leastsq(residuals, p0, maxfev=500)
    ## The last residual evaluation made by leastsq is not guaranteed to be
    ## at the returned solution (e.g. a finite-difference step or a rejected
    ## trial step), so store the solution explicitly.
    for key, best in zip(floating, p1):
        params[key]['value'] = best
    return success
## End of lma() function.
44    
45    
def model_interface(roi, pfloat):
    """
    Use the model assigned to this roi to simulate data vs time.

    pfloat is a sequence of values for the floating parameters only, i.e.
    the entries of roi.params whose 'fixed' flag is not True, taken in the
    iteration order of roi.params.  The values are written into roi.params
    so that the LMA fitter adjusts the parameters in place.

    Returns whatever roi.model(roi.time, roi.value, roi.params) returns.
    """
    params = roi.params
    ## Pair each floating key with its own position in pfloat.  Indexing
    ## pfloat by the position among ALL parameters would be wrong as soon
    ## as any parameter is fixed.
    floating = [key for key in params if params[key]['fixed'] != True]
    for key, val in zip(floating, pfloat):
        params[key]['value'] = val
    return roi.model(roi.time, roi.value, params)
57 clausted 19
58    
def checkparams(params):
    """
    Normalise a parameter dictionary in place.

    params maps a parameter name to a dictionary.  Each inner dictionary is
    given the four keys 'value', 'min', 'max' and 'fixed' if they are
    missing.  Defaults: 'value' is 0, 'min' and 'max' equal 'value', and
    'fixed' is True.

    For every floating parameter (its 'fixed' entry is not True, the same
    test the fitting functions use) the bounds are repaired so that
    min < value < max:
      * if min >= max, max is set to min + 1.0;
      * if value lies on or outside the bounds, it is set halfway between
        min and max.
    A one-line notice is printed when any repair was made.

    Returns None.
    """
    adjusted = False
    for entry in params.values():
        ## First make sure all four keys are present.
        entry.setdefault('value', 0)
        entry.setdefault('min', entry['value'])
        entry.setdefault('max', entry['value'])
        entry.setdefault('fixed', True)
        ## Then check that floating parameters satisfy min<value<max.
        if entry['fixed'] != True:
            if entry['min'] >= entry['max']:
                ## Set max to be 1+min.
                entry['max'] = float(entry['min']) + 1.0
                adjusted = True
            if (entry['value'] <= entry['min']) or (entry['value'] >= entry['max']):
                ## Set initial value halfway between min and max.
                entry['value'] = 0.5 * (entry['min'] + entry['max'])
                adjusted = True
    if adjusted:
        print("Parameter min/max errors: Modifications were made automatically.")
    return
## End of checkparams() function.
86    
87    
88 clausted 20 #### The code below is an alternative, non-working strategy. ####
def transform(a, b, x0):
    """
    Map x0 from the unbounded range (-inf,+inf) into the interval (a,b).

    Uses x1 = ((b-a)*tanh(x0) + b + a) / 2.  This lets the
    Levenberg-Marquardt fitter work on an unconstrained variable while the
    model only ever sees a value between the bounds.
    """
    ## Design notes kept from the original author:
    ##   x^2 was considered for (0,+inf) and a*tanh(x) for (-a,+a).
    ##   A logistic form a+(b-a)/(1+exp(-x)) was tried but had problems
    ##   (math.exp overflows for arguments above 709).
    span = b - a
    doubled = span * tanh(x0) + b + a
    return doubled / 2
## End of transform() function.
103    
104    
def itransform(a, b, x1):
    """
    Inverse of transform(): map x1 from the interval (a,b) to (-inf,+inf).

    Uses x0 = arctanh((2*x1 - b - a) / (b - a)).  This is used to help
    constrain the outputs of the Levenberg-Marquardt algorithm.

    x1 should lie strictly between a and b.  The arctanh argument reaches
    +/-1 at the bounds and exceeds that range outside them.
    """
    width = b - a
    ## The 2.0 forces a floating-point numerator, so integer arguments do
    ## not get floor-divided on Python 2.
    offset = 2.0 * x1 - b - a
    return arctanh(offset / width)
## End of itransform() function.
114    
115    
116     ## Warning: The next two functions are not yet working.
def clma(roi):
    """
    Constrained Levenberg-Marquardt fit based on scipy.optimize.leastsq.

    Takes a single region-of-interest object (roi).  Every parameter in
    roi.params whose 'fixed' entry is not True is fitted within its
    ('min','max') interval.  The fitter works on unconstrained variables
    obtained with itransform(); cerf() maps them back with transform()
    before evaluating the model.

    roi.params is modified in place: it is first normalised by
    checkparams(), and after the fit the floating parameters hold the
    solution returned by leastsq, mapped back into (min,max).

    Returns the integer status flag reported by leastsq (per the SciPy
    documentation, 1-4 mean a solution was found).  This is NOT an
    iteration count.
    """
    params = roi.params
    checkparams(params)

    ## Keys of the floating parameters, in dictionary iteration order.
    ## The list of starting values is built in this same order, so position
    ## j always refers to floating[j] even when some parameters are fixed.
    floating = [key for key in params if params[key]['fixed'] != True]
    p0 = []
    for key in floating:
        a, b = params[key]['min'], params[key]['max']
        ## Move the initial estimate into the unconstrained domain.
        start = itransform(a, b, float(params[key]['value']))
        p0.append(start)
        # Temp
        print("%s %s %s" % (key, params[key]['value'], start))

    ## Fit.  args must be a tuple, hence the trailing comma.
    p1, success = leastsq(cerf, p0, args=(roi,), maxfev=999)
    ## The last evaluation made by leastsq is not guaranteed to be at the
    ## returned solution, so store the solution explicitly.
    for key, fitted in zip(floating, p1):
        a, b = params[key]['min'], params[key]['max']
        params[key]['value'] = transform(a, b, fitted)
    return success
## End of constrained lma() function.
140    
141    
def cerf(pfloat, roi):
    """
    Error estimation function for use with constrained LMA.

    pfloat holds unconstrained values for the floating parameters only,
    i.e. the entries of roi.params whose 'fixed' flag is not True, taken in
    the iteration order of roi.params.  Each value is mapped into that
    parameter's ('min','max') interval with transform() and written to
    roi.params, so the fitter adjusts the parameters in place.

    Returns roi.model(roi.time, roi.value, roi.params) - roi.value.
    A temporary diagnostic line (parameter values and the sum of the
    residuals) is printed on every call.
    """
    params = roi.params
    ## Pair each floating key with its own position in pfloat.  Indexing
    ## pfloat by the position among ALL parameters would be wrong as soon
    ## as any parameter is fixed.
    floating = [key for key in params if params[key]['fixed'] != True]
    report = []
    for key, raw in zip(floating, pfloat):
        a, b = params[key]['min'], params[key]['max']
        params[key]['value'] = transform(a, b, raw)
        report.append("%s %f" % (key, params[key]['value']))  # Temp.
    erf = (roi.model(roi.time, roi.value, params) - roi.value)
    report.append("sum %0.1f" % np.sum(erf))  # Temp.
    ## One space-separated line, identical to the original chained prints.
    print(" ".join(report))
    return erf
## End of constrained lma erf() function.
161    
162    
163     '''
164     #### The code below is an alternative strategy that is not very robust. ####
165     def clma_old(roi):
166     """
167     Constrained Levenberg-Marquart fitting based on SciPy.
168     This function takes a single ba_class RegionOfInterest object (roi).
169     It modifies roi.params in place. It returns the number of iterations completed.
170     The error estimate is increased when a parameter moves outside the given bounds.
171     """
172     params = roi.params
173     checkparams(params)
174     ## Create list initial values for the floating parameters.
175     p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
176     ## Fit.
177     p1, success = leastsq(cerf_old, p0, args=(roi), maxfev=10000)
178     return success
179     ## End of clma() function.
180    
181    
182     def cerf_old(pfloat, roi):
183     """
184     Constraining Error Function.
185     Use the model assigned to this roi to simulation data vs time and calculate the error.
186     This allows the Constrained LMA fitter to adjust the parameters in place.
187     Provide a list of values for the floating parameters (pfloat).
188     The penalty factor is 1*(distance outside bounds).
189     Versus regular LMA, it's less robust--initialize with reasonable parameter guesses.
190     """
191     penalty = 1
192     for i,key in enumerate(roi.params):
193     if (roi.params[key]['fixed'] != True):
194     ## Substitute floating parameters in roi.params.
195     roi.params[key]['value'] = pfloat[i]
196     ## Impose penalties when outside constraints.
197     a = roi.params[key]['min']
198     b = roi.params[key]['max']
199     x = roi.params[key]['value']
200     if (x<a and a==0): penalty = penalty * (1 + (a-x))
201     if (x<a and a!=0): penalty = penalty * (1 + (a-x)/a)
202     if (x>b and b==0): penalty = penalty * (1 + (x-b))
203     if (x>b and b!=0): penalty = penalty * (1 + (x-b)/b)
204     print "%s %f" % (key, x),
205     print "penalty %0.2f" % (penalty)
206     errval = roi.model(roi.time, roi.value, roi.params) - roi.value
207     errval = errval * penalty
208     return errval
209     ## End of cerf() function.
210     '''
211    
212 clausted 19 ################################# End of module #################################