55import pandas as pd
66import scipy .stats as stats
77import statsmodels .api as sm
8+ import statsmodels .formula .api as smf
9+ from patsy import dmatrices
810
911from ..exceptions import PlotnineError , PlotnineWarning
1012from ..utils import get_valid_kwargs
@@ -47,17 +49,25 @@ def lm(data, xseq, **params):
4749 """
4850 Fit OLS / WLS if data has weight
4951 """
52+ if params ['formula' ]:
53+ return lm_formula (data , xseq , ** params )
54+
5055 X = sm .add_constant (data ['x' ])
5156 Xseq = sm .add_constant (xseq )
57+ weights = data .get ('weights' , None )
5258
53- if 'weight' in data :
54- init_kwargs , fit_kwargs = separate_method_kwargs (
55- params ['method_args' ], sm .WLS , sm .WLS .fit )
56- model = sm .WLS (data ['y' ], X , weights = data ['weight' ], ** init_kwargs )
57- else :
59+ if weights is None :
5860 init_kwargs , fit_kwargs = separate_method_kwargs (
5961 params ['method_args' ], sm .OLS , sm .OLS .fit )
6062 model = sm .OLS (data ['y' ], X , ** init_kwargs )
63+ else :
64+ if np .any (weights < 0 ):
65+ raise ValueError (
66+ "All weights must be greater than zero."
67+ )
68+ init_kwargs , fit_kwargs = separate_method_kwargs (
69+ params ['method_args' ], sm .WLS , sm .WLS .fit )
70+ model = sm .WLS (data ['y' ], X , weights = data ['weight' ], ** init_kwargs )
6171
6272 results = model .fit (** fit_kwargs )
6373 data = pd .DataFrame ({'x' : xseq })
@@ -74,10 +84,60 @@ def lm(data, xseq, **params):
7484 return data
7585
7686
87+ def lm_formula (data , xseq , ** params ):
88+ """
89+ Fit OLS / WLS using a formula
90+ """
91+ formula = params ['formula' ]
92+ eval_env = params ['enviroment' ]
93+ weights = data .get ('weight' , None )
94+
95+ if weights is None :
96+ init_kwargs , fit_kwargs = separate_method_kwargs (
97+ params ['method_args' ], sm .OLS , sm .OLS .fit )
98+ model = smf .ols (
99+ formula ,
100+ data ,
101+ eval_env = eval_env ,
102+ ** init_kwargs
103+ )
104+ else :
105+ if np .any (weights < 0 ):
106+ raise ValueError (
107+ "All weights must be greater than zero."
108+ )
109+ init_kwargs , fit_kwargs = separate_method_kwargs (
110+ params ['method_args' ], sm .OLS , sm .OLS .fit )
111+ model = smf .wls (
112+ formula ,
113+ data ,
114+ weights = weights ,
115+ eval_env = eval_env ,
116+ ** init_kwargs
117+ )
118+
119+ results = model .fit (** fit_kwargs )
120+ data = pd .DataFrame ({'x' : xseq })
121+ data ['y' ] = results .predict (data )
122+
123+ if params ['se' ]:
124+ _ , predictors = dmatrices (formula , data , eval_env = eval_env )
125+ alpha = 1 - params ['level' ]
126+ prstd , iv_l , iv_u = wls_prediction_std (
127+ results , predictors , alpha = alpha )
128+ data ['se' ] = prstd
129+ data ['ymin' ] = iv_l
130+ data ['ymax' ] = iv_u
131+ return data
132+
133+
77134def rlm (data , xseq , ** params ):
78135 """
79136 Fit RLM
80137 """
138+ if params ['formula' ]:
139+ return rlm_formula (data , xseq , ** params )
140+
81141 X = sm .add_constant (data ['x' ])
82142 Xseq = sm .add_constant (xseq )
83143
@@ -96,10 +156,38 @@ def rlm(data, xseq, **params):
96156 return data
97157
98158
159+ def rlm_formula (data , xseq , ** params ):
160+ """
161+ Fit RLM using a formula
162+ """
163+ eval_env = params ['enviroment' ]
164+ formula = params ['formula' ]
165+ init_kwargs , fit_kwargs = separate_method_kwargs (
166+ params ['method_args' ], sm .RLM , sm .RLM .fit )
167+ model = smf .rlm (
168+ formula ,
169+ data ,
170+ eval_env = eval_env ,
171+ ** init_kwargs
172+ )
173+ results = model .fit (** fit_kwargs )
174+ data = pd .DataFrame ({'x' : xseq })
175+ data ['y' ] = results .predict (data )
176+
177+ if params ['se' ]:
178+ warnings .warn ("Confidence intervals are not yet implemented"
179+ "for RLM smoothing." , PlotnineWarning )
180+
181+ return data
182+
183+
99184def gls (data , xseq , ** params ):
100185 """
101186 Fit GLS
102187 """
188+ if params ['formula' ]:
189+ return gls_formula (data , xseq , ** params )
190+
103191 X = sm .add_constant (data ['x' ])
104192 Xseq = sm .add_constant (xseq )
105193
@@ -122,10 +210,42 @@ def gls(data, xseq, **params):
122210 return data
123211
124212
213+ def gls_formula (data , xseq , ** params ):
214+ """
215+ Fit GLL using a formula
216+ """
217+ eval_env = params ['enviroment' ]
218+ formula = params ['formula' ]
219+ init_kwargs , fit_kwargs = separate_method_kwargs (
220+ params ['method_args' ], sm .GLS , sm .GLS .fit )
221+ model = smf .gls (
222+ formula ,
223+ data ,
224+ eval_env = eval_env ,
225+ ** init_kwargs
226+ )
227+ results = model .fit (** fit_kwargs )
228+ data = pd .DataFrame ({'x' : xseq })
229+ data ['y' ] = results .predict (data )
230+
231+ if params ['se' ]:
232+ _ , predictors = dmatrices (formula , data , eval_env = eval_env )
233+ alpha = 1 - params ['level' ]
234+ prstd , iv_l , iv_u = wls_prediction_std (
235+ results , predictors , alpha = alpha )
236+ data ['se' ] = prstd
237+ data ['ymin' ] = iv_l
238+ data ['ymax' ] = iv_u
239+ return data
240+
241+
125242def glm (data , xseq , ** params ):
126243 """
127244 Fit GLM
128245 """
246+ if params ['formula' ]:
247+ return glm_formula (data , xseq , ** params )
248+
129249 X = sm .add_constant (data ['x' ])
130250 Xseq = sm .add_constant (xseq )
131251
@@ -146,6 +266,29 @@ def glm(data, xseq, **params):
146266 return data
147267
148268
269+ def glm_formula (data , xseq , ** params ):
270+ eval_env = params ['enviroment' ]
271+ init_kwargs , fit_kwargs = separate_method_kwargs (
272+ params ['method_args' ], sm .GLM , sm .GLM .fit )
273+ model = smf .glm (
274+ params ['formula' ],
275+ data ,
276+ eval_env = eval_env ,
277+ ** init_kwargs
278+ )
279+ results = model .fit (** fit_kwargs )
280+ data = pd .DataFrame ({'x' : xseq })
281+ data ['y' ] = results .predict (data )
282+
283+ if params ['se' ]:
284+ df = pd .DataFrame ({'x' : xseq })
285+ prediction = results .get_prediction (df )
286+ ci = prediction .conf_int (1 - params ['level' ])
287+ data ['ymin' ] = ci [:, 0 ]
288+ data ['ymax' ] = ci [:, 1 ]
289+ return data
290+
291+
149292def lowess (data , xseq , ** params ):
150293 for k in ('is_sorted' , 'return_sorted' ):
151294 with suppress (KeyError ):
0 commit comments