ViewVC Help
View File | Revision Log | Show Annotations | View Changeset | Root Listing
root/osprai/osprai/trunk/fit_module.py
(Generate patch)
# Line 2 | Line 2
2   fit: Curve-fitting module for SPRI data.
3   Christopher Lausted, Institute for Systems Biology,
4   OSPRAI developers
5 < Last modified on 100422 (yymmdd)
5 > Last modified on 100425 (yymmdd)
6  
7   Example:
8   #import fit_module as fit
# Line 12 | Line 12
12   #success = lma(ba1.roi[0])
13   #for i in ba1.roi: lma(i)
14   """
15 < __version__ = "100422"
15 > __version__ = "100425"
16  
17  
18   ## Import libraries
# Line 20 | Line 20
20   import numpy as np
21   from scipy.optimize import leastsq
22   from copy import deepcopy
23 + from numpy import log, exp, tanh, arctanh
24 + #from math import exp
25  
26  
27   def lma(roi):
# Line 28 | Line 30
30      This function takes a single ba_class RegionOfInterest object (roi).
31      It modifies roi.params in place.  It returns the number of iterations completed.
32      """
31    time = roi.time
32    data = roi.value
33      params = roi.params
34      checkparams(params)
35      ## Create list initial values for the floating parameters.
36      p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
37      ## Error function.
38 <    erf = lambda p: (model_interface(roi, time, data, p) - data)
38 >    erf = lambda p: (model_interface(roi, p) - roi.value)
39      ## Fit.
40 <    #p1, success = leastsq(erf, p0, args=(), maxfev=10000)
41 <    p1, success = leastsq(erf, p0, maxfev=10000)
40 >    #p1, success = leastsq(erf, p0, args=(), maxfev=999)
41 >    p1, success = leastsq(erf, p0, maxfev=500)
42      return success
43      ## End of constlma() function.
44      
45      
46 < def model_interface(roi, time, data, pfloat):
46 > def model_interface(roi, pfloat):
47      """
48      Use the model assigned to this roi to simulation data vs time.  
49      Provide a list of values for the floating parameters (pfloat).
# Line 53 | Line 53
53      for i,key in enumerate(roi.params):
54          if (roi.params[key]['fixed'] != True):
55              roi.params[key]['value'] = pfloat[i]
56 <    return roi.model(time, data, roi.params)
56 >    return roi.model(roi.time, roi.value, roi.params)
57      
58      
59   def checkparams(params):
60      """
61      Check that each dictionary in the params dictionary contains the four
62 <    keys (value, min, max, fixed) and add them if necessary. Default is 'fixed':True.  
62 >    keys (value, min, max, fixed) and add them if necessary. Default is 'fixed':True.
63 >    Also check that when parameters float, min<value<max.    
64      """
65 +    flag = False
66      for key,dict in params.iteritems():
67 +        ## First check that keys are there.
68          if ('value' not in dict.keys()): dict['value'] = 0
69          if ('min' not in dict.keys()): dict['min'] = dict['value']
70          if ('max' not in dict.keys()): dict['max'] = dict['value']
71          if ('fixed' not in dict.keys()): dict['fixed'] = True
72 +        ## Check that when parameters float, min<value<max.
73 +        if (dict['fixed'] == False):
74 +            if (dict['min'] >= dict['max']):
75 +                ## Set max to be 1+min.
76 +                dict['max'] = float(dict['min']) + 1.0
77 +                flag = True
78 +            if (dict['value'] <= dict['min']) or (dict['value'] >= dict['max']):
79 +                ## Set initial value halfway between min and max.
80 +                dict['value'] = 0.5 * (dict['min'] + dict['max'])
81 +                flag = True
82 +    if (flag == True):
83 +        print "Parameter min/max errors: Modifications were made automatically."
84      return
85      ## End of checkparams() function.
86  
87  
88 + #### The code below is an alternative, non-working strategy. ####
89 + def transform(a, b, x0):
90 +    """
91 +    A transformation to convert x from range (-inf,+inf) to (a,b).
92 +    This is used to help constrain the outputs of the Levenberg-Marquart algorithm.
93 +    """
94 +    ## Considered transform:  x^2 for (0,+inf)
95 +    ## Considered transform:  a*tanh(x) for (-a,+a)
96 +    ## Considered transform:  a+((b-a)/(1+exp(-x)) for (a,b) but it has problems.
97 +    #x = max(-x0, -709)  ## Function math.exp(>709) overflows.
98 +    #x1 = a + ( (b-a) / (1+exp(-x) ))
99 +    ## Try transform: y = ((b-a)*tanh(x)+b+a)/2
100 +    x1 = ( (b-a)*tanh(x0) + b + a ) / 2
101 +    return x1
102 +    ## End of transform() function.
103 +    
104 +    
105 + def itransform(a, b, x1):
106 +    """
107 +    The (inverse) transformation to convert x from range(a,b) to (-inf,+inf).
108 +    This is used to help constrain the outputs of the Levenberg-Marquart algorithm.
109 +    """
110 +    #x0 = -log( ((b-a)/(x1-a)) - 1 )
111 +    x0 = arctanh( (2*x1-b-a) / (b-a) )
112 +    return x0
113 +    ## End of itransform() function.
114 +
115 +
116 + ## Warning:  The next two functions are not yet working.
117 + def clma(roi):
118 +    """
119 +    Constrained Levenberg-Marquart Algorithm fitting based on SciPy.
120 +    This function takes a single ba_class RegionOfInterest object (roi).
121 +    It modifies roi.params in place.  It returns the number of iterations completed.
122 +    """
123 +    params = roi.params
124 +    checkparams(params)
125 +    
126 +    ## Create list of initial values for the floating parameters.
127 +    p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
128 +    ## Adjust list based on min/max constraints.
129 +    for i,key in enumerate(params):
130 +        if (params[key]['fixed'] != True):
131 +            a, b = params[key]['min'], params[key]['max']
132 +            p0[i] = itransform(a, b, float(p0[i]))
133 +            # Temp
134 +            print key, params[key]['value'], p0[i]
135 +    
136 +    ## Fit.
137 +    p1, success = leastsq(cerf, p0, args=(roi), maxfev=999)
138 +    return success
139 +    ## End of constrained lma() function.
140 +    
141 +    
142 + def cerf(pfloat, roi):
143 +    """
144 +    Error estimation function for use with constrained LMA.
145 +    Use the model assigned to this roi to simulation data vs time.  
146 +    Provide a list of values for the floating parameters (pfloat).
147 +    This list is adjusted based on min/max constraints in the parameter dictionary.
148 +    This function will write the values to the parameter dictionary (roi.params).
149 +    This allows the LMA fitter to adjust the parameters in place.
150 +    """
151 +    params = roi.params
152 +    for i,key in enumerate(params):
153 +        if (params[key]['fixed'] != True):
154 +            a, b = params[key]['min'], params[key]['max']
155 +            params[key]['value'] = transform(a, b, pfloat[i])
156 +            print "%s %f" % (key, params[key]['value']), # Temp.
157 +    erf = (roi.model(roi.time, roi.value, params) - roi.value)
158 +    print "sum %0.1f" % np.sum(erf) # Temp.
159 +    return erf
160 +    ## End of constrained lma erf() function.
161 +    
162 +    
163 + '''    
164 + #### The code below is an alternative strategy that is not very robust. ####
165 + def clma_old(roi):
166 +    """
167 +    Constrained Levenberg-Marquart fitting based on SciPy.
168 +    This function takes a single ba_class RegionOfInterest object (roi).
169 +    It modifies roi.params in place.  It returns the number of iterations completed.
170 +    The error estimate is increased when a parameter moves outside the given bounds.
171 +    """
172 +    params = roi.params
173 +    checkparams(params)
174 +    ## Create list initial values for the floating parameters.
175 +    p0 = [dict['value'] for key,dict in params.iteritems() if (dict['fixed'] != True)]
176 +    ## Fit.
177 +    p1, success = leastsq(cerf_old, p0, args=(roi), maxfev=10000)
178 +    return success
179 +    ## End of clma() function.
180 +    
181 +    
182 + def cerf_old(pfloat, roi):
183 +    """
184 +    Constraining Error Function.
185 +    Use the model assigned to this roi to simulation data vs time and calculate the error.  
186 +    This allows the Constrained LMA fitter to adjust the parameters in place.
187 +    Provide a list of values for the floating parameters (pfloat).
188 +    The penalty factor is 1*(distance outside bounds).
189 +    Versus regular LMA, it's less robust--initialize with reasonable parameter guesses.
190 +    """
191 +    penalty = 1
192 +    for i,key in enumerate(roi.params):
193 +        if (roi.params[key]['fixed'] != True):
194 +            ## Substitute floating parameters in roi.params.
195 +            roi.params[key]['value'] = pfloat[i]
196 +            ## Impose penalties when outside constraints.
197 +            a = roi.params[key]['min']
198 +            b = roi.params[key]['max']
199 +            x = roi.params[key]['value']
200 +            if (x<a and a==0): penalty = penalty * (1 + (a-x))
201 +            if (x<a and a!=0): penalty = penalty * (1 + (a-x)/a)
202 +            if (x>b and b==0): penalty = penalty * (1 + (x-b))
203 +            if (x>b and b!=0): penalty = penalty * (1 + (x-b)/b)
204 +            print "%s %f" % (key, x),
205 +    print "penalty %0.2f" % (penalty)
206 +    errval = roi.model(roi.time, roi.value, roi.params) - roi.value
207 +    errval = errval * penalty
208 +    return errval
209 +    ## End of cerf() function.
210 + '''
211 +
212   ################################# End of module #################################

Diff Legend

Removed lines
+ Added lines
< Changed lines
> Changed lines