ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
Revision: 22
Committed: Tue Apr 27 07:17:58 2010 UTC (9 years, 7 months ago) by clausted
File size: 8386 byte(s)
Log Message:
The new module osprai_one provides access to the important functions in all of the other modules.  Documentation has been started with osprai_one.html.
Line User Rev File contents
1 clausted 19 """
2     fit: Curve-fitting module for SPRI data.
3     Christopher Lausted, Institute for Systems Biology,
4     OSPRAI developers
5 clausted 20 Last modified on 100425 (yymmdd)
6 clausted 19
7     Example:
8     #import fit_module as fit
9     #import mdl_module as mdl
10     #ba1.roi[0].model = mdl.drift
11     #ba1.roi[0].params = {'rate': {'value':1, 'min':-100.0, 'max':100.0, 'fixed':False} }
12     #success = lma(ba1.roi[0])
13     #for i in ba1.roi: lma(i)
14     """
15 clausted 20 __version__ = "100425"
16 clausted 19
17    
18     ## Import libraries
19     import ba_class as ba
20     import numpy as np
21     from scipy.optimize import leastsq
22     from copy import deepcopy
23 clausted 20 from numpy import log, exp, tanh, arctanh
24     #from math import exp
25 clausted 19
26    
def lma(roi):
    """
    Normal Levenberg-Marquardt fitting based on SciPy.

    This function takes a single ba_class RegionOfInterest object (roi).
    It modifies roi.params in place: after the fit, the 'value' entry of
    every floating parameter (one whose 'fixed' entry is not True) holds
    the solution returned by the optimizer.

    Returns the integer status flag reported by scipy.optimize.leastsq.
    According to the SciPy documentation, a flag of 1, 2, 3 or 4 means a
    solution was found; it is not an iteration count.
    """
    params = roi.params
    checkparams(params)
    ## Create list of initial values for the floating parameters.
    ## The order matches the iteration order used by model_interface().
    p0 = [entry['value'] for entry in params.values() if (entry['fixed'] != True)]
    ## Error function: simulated data minus measured data.
    erf = lambda p: (model_interface(roi, p) - roi.value)
    ## Fit.
    p1, success = leastsq(erf, p0, maxfev=500)
    ## leastsq does not promise that its last call to erf used the solution
    ## it returns, so push the solution through the interface once more to
    ## make sure roi.params ends up holding it.
    model_interface(roi, np.atleast_1d(p1))
    return success
## End of lma() function.
44    
45    
def model_interface(roi, pfloat):
    """
    Use the model assigned to this roi to simulate data vs time.

    Provide a list of values for the floating parameters (pfloat): one value
    per parameter whose 'fixed' entry is not True, in the iteration order of
    roi.params.
    This function writes the values to the parameter dictionary (roi.params),
    which allows the LMA fitter to adjust the parameters in place.

    Returns the result of roi.model(roi.time, roi.value, roi.params).
    """
    ## Fixed parameters occupy no slot in pfloat, so the floating parameters
    ## are counted separately rather than indexed by their position in
    ## roi.params.
    slot = 0
    for key in roi.params:
        if (roi.params[key]['fixed'] != True):
            roi.params[key]['value'] = pfloat[slot]
            slot += 1
    return roi.model(roi.time, roi.value, roi.params)
57 clausted 19
58    
def checkparams(params):
    """
    Validate and repair a parameter dictionary in place.

    Each entry of params is itself a dictionary that should contain the four
    keys (value, min, max, fixed). Missing keys are added: 'value' defaults
    to 0, 'min' and 'max' default to the value, and 'fixed' defaults to True.

    For floating parameters ('fixed' is not True, the same test the fitting
    functions use) it also enforces min<value<max:
      - if min >= max, max is set to min + 1.0;
      - if value is not strictly inside (min, max), value is set to the
        midpoint of min and max.
    A one-line notice is printed when any such correction was made.
    Returns None.
    """
    modified = False
    for entry in params.values():
        ## First check that keys are there.
        entry.setdefault('value', 0)
        entry.setdefault('min', entry['value'])
        entry.setdefault('max', entry['value'])
        entry.setdefault('fixed', True)
        ## Check that when parameters float, min<value<max.
        if (entry['fixed'] != True):
            if (entry['min'] >= entry['max']):
                ## Set max to be 1+min.
                entry['max'] = float(entry['min']) + 1.0
                modified = True
            if (entry['value'] <= entry['min']) or (entry['value'] >= entry['max']):
                ## Set initial value halfway between min and max.
                entry['value'] = 0.5 * (entry['min'] + entry['max'])
                modified = True
    if modified:
        ## Single-argument print(...) behaves the same on Python 2 and 3.
        print("Parameter min/max errors: Modifications were made automatically.")
    return
## End of checkparams() function.
86    
87    
88 clausted 22 #### The code below is an alternative strategy for constrained parameters ####
def transform(a, b, x0):
    """
    Map an unbounded value x0 from (-inf,+inf) onto the range (a,b).

    This is used to help constrain the outputs of the Levenberg-Marquardt
    algorithm: the optimizer varies x0 freely while the returned value
    stays between a and b.

    Formula: y = ((b-a)*tanh(x0) + b + a) / 2
    """
    ## Alternatives that were considered:
    ##   x^2 for (0,+inf)
    ##   a*tanh(x) for (-a,+a)
    ##   a + (b-a)/(1+exp(-x)) for (a,b), but it has problems
    ##   (math.exp overflows for arguments above 709).
    width = b - a
    scaled = width * tanh(x0)
    return (scaled + b + a) / 2
## End of transform() function.
103    
104    
def itransform(a, b, x1):
    """
    The (inverse) transformation to convert x1 from range (a,b) to (-inf,+inf).

    This is used to help constrain the outputs of the Levenberg-Marquardt
    algorithm. It inverts y = ((b-a)*tanh(x) + b + a) / 2.

    Requires a != b (the expression divides by b-a). Values of x1 at the
    ends of the range correspond to arctanh(-1) or arctanh(+1).
    """
    ## The float literal forces true division even when a, b and x1 are all
    ## integers; under Python 2 the plain "/" would truncate the ratio.
    ratio = (2.0 * x1 - b - a) / (b - a)
    x0 = arctanh(ratio)
    return x0
## End of itransform() function.
114    
115    
def clma(roi):
    """
    Constrained Levenberg-Marquardt Algorithm fitting based on SciPy.

    This function takes a single ba_class RegionOfInterest object (roi).
    The optimizer works on unbounded variables that are mapped into each
    floating parameter's (min, max) range with transform()/itransform().
    It modifies roi.params in place: after the fit, the 'value' entry of
    every floating parameter holds the (bounded) solution.

    Returns the integer status flag reported by scipy.optimize.leastsq.
    According to the SciPy documentation, a flag of 1, 2, 3 or 4 means a
    solution was found; it is not an iteration count.
    """
    params = roi.params
    checkparams(params)

    ## Keys of the floating parameters, in the iteration order that cerf()
    ## uses. Fixed parameters take no slot in the optimizer's vector.
    floating = [key for key in params if (params[key]['fixed'] != True)]

    ## Create list of initial values for the floating parameters,
    ## adjusted based on min/max constraints.
    p0 = []
    for key in floating:
        a, b = params[key]['min'], params[key]['max']
        start = itransform(a, b, float(params[key]['value']))
        p0.append(start)
        # Temp
        print("%s %s %s" % (key, params[key]['value'], start))

    ## Fit. Note that args must be a tuple, hence the trailing comma.
    p1, success = leastsq(cerf, p0, args=(roi,), maxfev=999)

    ## leastsq does not promise that its last call to cerf used the solution
    ## it returns, so store the bounded solution explicitly.
    for key, x in zip(floating, np.atleast_1d(p1)):
        a, b = params[key]['min'], params[key]['max']
        params[key]['value'] = transform(a, b, x)
    return success
## End of clma() function.
139    
140    
def cerf(pfloat, roi):
    """
    Error estimation function for use with constrained LMA.

    Use the model assigned to this roi to simulate data vs time.
    Provide a list of values for the floating parameters (pfloat): one
    unbounded value per parameter whose 'fixed' entry is not True, in the
    iteration order of roi.params.
    Each value is mapped into the parameter's (min, max) range with
    transform() and written to the parameter dictionary (roi.params).
    This allows the LMA fitter to adjust the parameters in place.

    Returns the residual roi.model(roi.time, roi.value, params) - roi.value.
    A one-line progress report is printed on every call (temporary).
    """
    params = roi.params
    report = []  # Temp: pieces of the one-line progress report.
    ## Fixed parameters occupy no slot in pfloat, so the floating parameters
    ## are counted separately rather than indexed by their position in params.
    slot = 0
    for key in params:
        if (params[key]['fixed'] != True):
            a, b = params[key]['min'], params[key]['max']
            params[key]['value'] = transform(a, b, pfloat[slot])
            slot += 1
            report.append("%s %f" % (key, params[key]['value']))  # Temp.
    erf = (roi.model(roi.time, roi.value, params) - roi.value)
    report.append("sum %0.1f" % np.sum(erf))  # Temp.
    ## Single-argument print(...) behaves the same on Python 2 and 3.
    print(" ".join(report))  # Temp.
    return erf
## End of cerf() function.
160    
161    
162     '''
163     #### The code below is an alternative strategy that is not very robust. ####
164     def clma_old(roi):
165     """
166     Constrained Levenberg-Marquart fitting based on SciPy.
167     This function takes a single ba_class RegionOfInterest object (roi).
168     It modifies roi.params in place. It returns the number of iterations completed.
169     The error estimate is increased when a parameter moves outside the given bounds.
170     """
171     params = roi.params
172     checkparams(params)
173     ## Create list initial values for the floating parameters.
174     p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
175     ## Fit.
176     p1, success = leastsq(cerf_old, p0, args=(roi), maxfev=10000)
177     return success
178     ## End of clma() function.
179    
180    
181     def cerf_old(pfloat, roi):
182     """
183     Constraining Error Function.
184     Use the model assigned to this roi to simulation data vs time and calculate the error.
185     This allows the Constrained LMA fitter to adjust the parameters in place.
186     Provide a list of values for the floating parameters (pfloat).
187     The penalty factor is 1*(distance outside bounds).
188     Versus regular LMA, it's less robust--initialize with reasonable parameter guesses.
189     """
190     penalty = 1
191     for i,key in enumerate(roi.params):
192     if (roi.params[key]['fixed'] != True):
193     ## Substitute floating parameters in roi.params.
194     roi.params[key]['value'] = pfloat[i]
195     ## Impose penalties when outside constraints.
196     a = roi.params[key]['min']
197     b = roi.params[key]['max']
198     x = roi.params[key]['value']
199     if (x<a and a==0): penalty = penalty * (1 + (a-x))
200     if (x<a and a!=0): penalty = penalty * (1 + (a-x)/a)
201     if (x>b and b==0): penalty = penalty * (1 + (x-b))
202     if (x>b and b!=0): penalty = penalty * (1 + (x-b)/b)
203     print "%s %f" % (key, x),
204     print "penalty %0.2f" % (penalty)
205     errval = roi.model(roi.time, roi.value, roi.params) - roi.value
206     errval = errval * penalty
207     return errval
208     ## End of cerf() function.
209     '''
210    
211 clausted 19 ################################# End of module #################################