Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2Optimizer with cross validation score 

3""" 

4 

5import numpy as np 

6from sklearn.model_selection import KFold, ShuffleSplit 

7from typing import Any, Dict, Tuple 

8from .base_optimizer import BaseOptimizer 

9from .optimizer import Optimizer 

10from .fit_methods import fit 

11from .tools import ScatterData 

12 

13 

# Maps the user-facing validation-method name to the corresponding
# scikit-learn splitter class.
validation_methods = dict([
    ('k-fold', KFold),
    ('shuffle-split', ShuffleSplit),
])

18 

19 

class CrossValidationEstimator(BaseOptimizer):
    """
    This class provides an optimizer with cross validation for solving the
    linear :math:`\\boldsymbol{A}\\boldsymbol{x} = \\boldsymbol{y}` problem.
    Cross-validation (CV) scores are calculated by splitting the
    available reference data in multiple different ways. It also produces
    the finalized model (using the full input data) for which the CV score
    is an estimation of its performance.

    Warning
    -------
    Repeatedly setting up a CrossValidationEstimator and training
    *without* changing the seed for the random number generator will yield
    identical or correlated results, to avoid this please specify a different
    seed when setting up multiple CrossValidationEstimator instances.

    Parameters
    ----------
    fit_data : tuple(numpy.ndarray, numpy.ndarray)
        the first element of the tuple represents the fit matrix `A`
        (`N, M` array) while the second element represents the vector
        of target values `y` (`N` array); here `N` (=rows of `A`,
        elements of `y`) equals the number of target values and `M`
        (=columns of `A`) equals the number of parameters
    fit_method : str
        method to be used for training; possible choices are
        "least-squares", "lasso", "elasticnet", "bayesian-ridge", "ardr",
        "rfe", "split-bregman"
    standardize : bool
        if True the fit matrix and target values are standardized before
        fitting, meaning columns in the fit matrix and the target values are
        rescaled to have a standard deviation of 1.0.
    validation_method : str
        method to use for cross-validation; possible choices are
        "shuffle-split", "k-fold"
    n_splits : int
        number of times the fit data set will be split for the cross-validation
    check_condition : bool
        if True the condition number will be checked
        (this can be slightly more time consuming for larger
        matrices)
    seed : int
        seed for pseudo random number generator

    Attributes
    ----------
    train_scatter_data : ScatterData
        contains target and predicted values from each individual
        training set in the cross-validation split;
        :class:`ScatterData` is a namedtuple.
    validation_scatter_data : ScatterData
        contains target and predicted values from each individual
        validation set in the cross-validation split;
        :class:`ScatterData` is a namedtuple.

    """

    def __init__(self,
                 fit_data: Tuple[np.ndarray, np.ndarray],
                 fit_method: str = 'least-squares',
                 standardize: bool = True,
                 validation_method: str = 'k-fold',
                 n_splits: int = 10,
                 check_condition: bool = True,
                 seed: int = 42,
                 **kwargs) -> None:

        super().__init__(fit_data, fit_method, standardize, check_condition, seed)

        # fail early with a helpful message for an unknown validation method
        if validation_method not in validation_methods:
            msg = ['Validation method not available']
            msg += ['Please choose one of the following:']
            for key in validation_methods:
                msg += [' * ' + key]
            raise ValueError('\n'.join(msg))
        self._validation_method = validation_method
        self._n_splits = n_splits
        # route **kwargs to either the splitter or the fit method
        self._set_kwargs(kwargs)

        # data set splitting object
        self._splitter = validation_methods[validation_method](
            n_splits=self.n_splits, random_state=seed,
            **self._split_kwargs)

        # populated by validate()
        self.train_scatter_data = None
        self.validation_scatter_data = None
        self._parameters_splits = None
        self._rmse_train_splits = None
        self._rmse_valid_splits = None
        # populated by train()
        self._rmse_train_final = None

    def train(self) -> None:
        """ Constructs the final model using all input data available. """
        self._fit_results = fit(self._A, self._y, self.fit_method,
                                self.standardize, self._check_condition,
                                **self._fit_kwargs)
        self._rmse_train_final = self.compute_rmse(self._A, self._y)

    def validate(self) -> None:
        """ Runs validation. """
        train_target, train_predicted = [], []
        valid_target, valid_predicted = [], []
        rmse_train_splits, rmse_valid_splits = [], []
        parameters_splits = []
        for train_set, test_set in self._splitter.split(self._A):
            # train an independent model on each train/validation split
            opt = Optimizer((self._A, self._y), self.fit_method,
                            standardize=self.standardize,
                            train_set=train_set,
                            test_set=test_set,
                            check_condition=self._check_condition,
                            **self._fit_kwargs)
            opt.train()

            parameters_splits.append(opt.parameters)
            rmse_train_splits.append(opt.rmse_train)
            rmse_valid_splits.append(opt.rmse_test)
            train_target.extend(opt.train_scatter_data.target)
            train_predicted.extend(opt.train_scatter_data.predicted)
            valid_target.extend(opt.test_scatter_data.target)
            valid_predicted.extend(opt.test_scatter_data.predicted)

        self._parameters_splits = np.array(parameters_splits)
        self._rmse_train_splits = np.array(rmse_train_splits)
        self._rmse_valid_splits = np.array(rmse_valid_splits)
        self.train_scatter_data = ScatterData(
            target=np.array(train_target), predicted=np.array(train_predicted))
        self.validation_scatter_data = ScatterData(
            target=np.array(valid_target), predicted=np.array(valid_predicted))

    def _set_kwargs(self, kwargs: dict) -> None:
        """
        Sets up fit_kwargs and split_kwargs.
        Different split methods need different keywords.
        """
        self._fit_kwargs = {}
        self._split_kwargs = {}

        if self.validation_method == 'k-fold':
            self._split_kwargs['shuffle'] = True  # default True
            for key, val in kwargs.items():
                if key in ['shuffle']:
                    self._split_kwargs[key] = val
                else:
                    self._fit_kwargs[key] = val
        elif self.validation_method == 'shuffle-split':
            for key, val in kwargs.items():
                if key in ['test_size', 'train_size']:
                    self._split_kwargs[key] = val
                else:
                    self._fit_kwargs[key] = val

    @property
    def summary(self) -> Dict[str, Any]:
        """ comprehensive information about the optimizer """
        info = super().summary

        # Add class specific data
        info['validation_method'] = self.validation_method
        info['n_splits'] = self.n_splits
        info['rmse_train_final'] = self.rmse_train_final
        info['rmse_train'] = self.rmse_train
        info['rmse_train_splits'] = self.rmse_train_splits
        info['rmse_validation'] = self.rmse_validation
        info['rmse_validation_splits'] = self.rmse_validation_splits
        info['train_scatter_data'] = self.train_scatter_data
        info['validation_scatter_data'] = self.validation_scatter_data

        # add kwargs used for fitting and splitting
        info = {**info, **self._fit_kwargs, **self._split_kwargs}
        return info

    def __repr__(self) -> str:
        kwargs = dict()
        kwargs['fit_method'] = self.fit_method
        kwargs['validation_method'] = self.validation_method
        kwargs['n_splits'] = self.n_splits
        kwargs['seed'] = self.seed
        kwargs = {**kwargs, **self._fit_kwargs, **self._split_kwargs}
        return 'CrossValidationEstimator((A, y), {})'.format(
            ', '.join('{}={}'.format(*kwarg) for kwarg in kwargs.items()))

    @property
    def validation_method(self) -> str:
        """ validation method name """
        return self._validation_method

    @property
    def n_splits(self) -> int:
        """ number of splits (folds) used for cross-validation """
        return self._n_splits

    @property
    def parameters_splits(self) -> np.ndarray:
        """ all parameters obtained during cross-validation """
        return self._parameters_splits

    @property
    def n_nonzero_parameters_splits(self) -> np.ndarray:
        """ number of non-zero parameters for each split """
        if self.parameters_splits is None:
            return None
        else:
            return np.array([np.count_nonzero(p) for p in self.parameters_splits])

    @property
    def rmse_train_final(self) -> float:
        """
        root mean squared error when using the full set of input data
        """
        return self._rmse_train_final

    @property
    def rmse_train(self) -> float:
        """
        average root mean squared training error obtained during
        cross-validation
        """
        if self._rmse_train_splits is None:
            return None
        # quadratic mean of the per-split RMSEs, not a plain average
        return np.sqrt(np.mean(self._rmse_train_splits**2))

    @property
    def rmse_train_splits(self) -> np.ndarray:
        """
        root mean squared training errors obtained during
        cross-validation
        """
        return self._rmse_train_splits

    @property
    def rmse_validation(self) -> float:
        """ average root mean squared cross-validation error """
        if self._rmse_valid_splits is None:
            return None
        # quadratic mean of the per-split RMSEs, not a plain average
        return np.sqrt(np.mean(self._rmse_valid_splits**2))

    @property
    def rmse_validation_splits(self) -> np.ndarray:
        """
        root mean squared validation errors obtained during
        cross-validation
        """
        return self._rmse_valid_splits