Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
This module provides the ClusterExpansion class.
"""


import os
import pickle
import re
import tarfile
import tempfile
from typing import List, Union

import numpy as np
import pandas as pd
from ase import Atoms

from icet import ClusterSpace
from icet.core.structure import Structure



class ClusterExpansion:
    """Cluster expansions are obtained by combining a cluster space with a set
    of parameters, where the latter is commonly obtained by optimization.
    Instances of this class allow one to predict the property of interest for
    a given structure.

    **Note:** Each element of the parameter vector corresponds to an
    effective cluster interaction (ECI) multiplied by the multiplicity of the
    underlying orbit.

    Attributes
    ----------
    parameters : np.ndarray
        parameter vector
    metadata : dict
        metadata stored together with the cluster expansion

    The cluster space that was used for constructing the cluster expansion
    can be obtained via :meth:`get_cluster_space_copy`.

    Example
    -------
    The following snippet illustrates the initialization and usage of
    a ClusterExpansion object. Here, the parameters are taken to be a list
    of ones. Usually, they would be obtained by training with
    respect to a set of reference data::

        >>> from ase.build import bulk
        >>> from icet import ClusterSpace, ClusterExpansion

        >>> # create cluster expansion with fake parameters
        >>> prim = bulk('Au')
        >>> cs = ClusterSpace(prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Pd']])
        >>> parameters = len(cs) * [1.0]
        >>> ce = ClusterExpansion(cs, parameters)

        >>> # make prediction for supercell
        >>> sc = prim.repeat(3)
        >>> for k in [1, 4, 7]:
        ...     sc[k].symbol = 'Pd'
        >>> print(ce.predict(sc))
    """

    def __init__(self, cluster_space: ClusterSpace,
                 parameters: Union[List[float], np.ndarray],
                 metadata: dict = None) -> None:
        """
        Initializes a ClusterExpansion object.

        Parameters
        ----------
        cluster_space
            cluster space to be used for constructing the cluster expansion;
            a copy is stored
        parameters
            parameter vector (any sequence of numbers); stored as a new
            numpy array
        metadata : dict
            metadata dictionary, user-defined metadata to be stored together
            with cluster expansion. Will be pickled when CE is written to file.
            By default contains icet version, username, hostname and date.
            The dictionary passed in is copied and not modified.

        Raises
        ------
        ValueError
            if cluster space and parameters differ in length
        """
        if len(cluster_space) != len(parameters):
            raise ValueError('cluster_space ({}) and parameters ({}) must have'
                             ' the same length'.format(len(cluster_space), len(parameters)))
        self._cluster_space = cluster_space.copy()
        # Always store an ndarray (also for tuples and other sequences), since
        # prune() relies on fancy indexing. Copying decouples the stored vector
        # from the caller's object.
        self._parameters = np.array(parameters)

        # Copy the metadata so that adding the default entries below does not
        # modify the dictionary owned by the caller.
        self._metadata = dict(metadata) if metadata is not None else dict()
        self._add_default_metadata()

    def predict(self, structure: Union[Atoms, Structure]) -> float:
        """
        Predicts the property of interest (e.g., the energy) for the input
        structure using the cluster expansion.

        Parameters
        ----------
        structure
            atomic configuration

        Returns
        -------
        float
            property value predicted by the cluster expansion, i.e. the dot
            product of the cluster vector of the structure and the parameters
        """
        cluster_vector = self._cluster_space.get_cluster_vector(structure)
        prop = np.dot(cluster_vector, self.parameters)
        return prop

    def get_cluster_space_copy(self) -> ClusterSpace:
        """ Gets copy of cluster space on which cluster expansion is based """
        return self._cluster_space.copy()

    def to_dataframe(self) -> pd.DataFrame:
        """Returns representation of the cluster expansion in the form of a
        DataFrame containing orbit information and effective cluster
        interactions (ECIs).

        Each row contains the orbit data of the cluster space plus the
        columns ``parameter`` (the raw parameter) and ``eci`` (the parameter
        divided by the multiplicity of the orbit).
        """
        # copy the rows so that the dicts provided by the cluster space are
        # not modified in place
        rows = [dict(row) for row in self._cluster_space.orbit_data]
        for row, param in zip(rows, self.parameters):
            row['parameter'] = param
            row['eci'] = param / row['multiplicity']
        return pd.DataFrame(rows)

    @property
    def orders(self) -> List[int]:
        """ orders included in cluster expansion (0 up to and including
        the number of cutoffs plus one) """
        return list(range(len(self._cluster_space.cutoffs) + 2))

    @property
    def parameters(self) -> np.ndarray:
        """ parameter vector; each element of the parameter vector corresponds
        to an effective cluster interaction (ECI) multiplied by the
        multiplicity of the respective orbit """
        return self._parameters

    @property
    def metadata(self) -> dict:
        """ metadata associated with cluster expansion """
        return self._metadata

    @property
    def symprec(self) -> float:
        """ tolerance imposed when analyzing the symmetry using spglib
        (inherited from the underlying cluster space) """
        return self._cluster_space.symprec

    @property
    def position_tolerance(self) -> float:
        """ tolerance applied when comparing positions in Cartesian coordinates
        (inherited from the underlying cluster space) """
        return self._cluster_space.position_tolerance

    @property
    def fractional_position_tolerance(self) -> float:
        """ tolerance applied when comparing positions in fractional coordinates
        (inherited from the underlying cluster space) """
        return self._cluster_space.fractional_position_tolerance

    @property
    def primitive_structure(self) -> Atoms:
        """ primitive structure on which cluster expansion is based (copy) """
        return self._cluster_space.primitive_structure.copy()

    def plot_ecis(self, orders: List[int] = None) -> None:
        """ Plots the ECIs as a function of cluster radius.

        Parameters
        ----------
        orders
            orders for which to plot ECIs; by default all orders in
            :attr:`orders` are plotted
        """
        if orders is None:
            orders = self.orders
        df = self.to_dataframe()

        # plotting; deferred import so matplotlib is only needed when this
        # method is actually called
        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.axhline(y=0.0, c='k', lw=1)
        for order in orders:
            df_order = df.loc[df['order'] == order]
            ax.plot(df_order.radius, df_order.eci, 'o', ms=8, label='order {}'.format(order))
        ax.legend(loc='best')
        ax.set_xlabel('Radius')
        ax.set_ylabel('ECI')
        plt.show()

    def __len__(self) -> int:
        """ number of parameters in the cluster expansion """
        return len(self._parameters)

    def _get_string_representation(self, print_threshold: int = None,
                                   print_minimum: int = 10) -> str:
        """ String representation of the cluster expansion.

        The orbit table of the cluster space is extended by a parameter and an
        ECI column and preceded by a summary of nonzero parameters.

        Parameters
        ----------
        print_threshold
            if the number of orbits exceeds this number, only the first and
            last `print_minimum` rows are shown, separated by dots
        print_minimum
            number of rows shown at the top and the bottom of the table if
            `print_threshold` is exceeded
        """
        cluster_space_repr = self._cluster_space._get_string_representation(
            print_threshold, print_minimum).split('\n')
        # rescale width
        par_col_width = max(len('{:9.3g}'.format(max(self._parameters, key=abs))), len('ECI'))
        width = len(cluster_space_repr[0]) + 2 * (len(' | ') + par_col_width)

        s = []  # type: List
        s += ['{s:=^{n}}'.format(s=' Cluster Expansion ', n=width)]
        s += [t for t in cluster_space_repr if re.search(':', t)]

        # additional information about number of nonzero parameters
        df = self.to_dataframe()
        orders = self.orders
        nzp_by_order = [np.count_nonzero(df[df.order == order].eci) for order in orders]
        assert sum(nzp_by_order) == np.count_nonzero(self.parameters)
        s += [' {:38} : {}'.format('total number of nonzero parameters', sum(nzp_by_order))]
        line = ' {:38} :'.format('number of nonzero parameters by order')
        for order, nzp in zip(orders, nzp_by_order):
            line += ' {}= {} '.format(order, nzp)
        s += [line]

        # table header
        s += [''.center(width, '-')]
        t = [t for t in cluster_space_repr if 'index' in t]
        t += ['{s:^{n}}'.format(s='parameter', n=par_col_width)]
        t += ['{s:^{n}}'.format(s='ECI', n=par_col_width)]
        s += [' | '.join(t)]
        s += [''.center(width, '-')]

        # table body
        # Fetch the multiplicities once instead of accessing orbit_data for
        # every row, since that attribute may rebuild the full orbit table on
        # each access.
        multiplicities = [row['multiplicity'] for row in self._cluster_space.orbit_data]
        index = 0
        while index < len(self):
            if (print_threshold is not None and
                    len(self) > print_threshold and
                    index >= print_minimum and
                    index <= len(self) - print_minimum):
                # skip the middle part of the table
                index = len(self) - print_minimum
                s += [' ...']
            # pick the row of the cluster space table that belongs to this index
            pattern = r'^{:4}'.format(index)
            t = [t for t in cluster_space_repr if re.match(pattern, t)]
            parameter = self._parameters[index]
            t += ['{s:^{n}}'.format(s=f'{parameter:9.3g}', n=par_col_width)]
            eci = parameter / multiplicities[index]
            t += ['{s:^{n}}'.format(s=f'{eci:9.3g}', n=par_col_width)]
            s += [' | '.join(t)]
            index += 1
        s += [''.center(width, '=')]

        return '\n'.join(s)

    def __repr__(self) -> str:
        """ string representation """
        return self._get_string_representation(print_threshold=50)

    def print_overview(self,
                       print_threshold: int = None,
                       print_minimum: int = 10) -> None:
        """
        Print an overview of the cluster expansion in terms of the orbits (order,
        radius, multiplicity, corresponding ECI etc).

        Parameters
        ----------
        print_threshold
            if the number of orbits exceeds this number print dots
        print_minimum
            number of lines printed from the top and the bottom of the orbit
            list if `print_threshold` is exceeded
        """
        print(self._get_string_representation(print_threshold=print_threshold,
                                              print_minimum=print_minimum))

    def prune(self, indices: List[int] = None, tol: float = 0) -> None:
        """
        Removes orbits from the cluster expansion (CE), for which the absolute
        values of the corresponding parameters are zero or close to zero. This
        commonly reduces the computational cost for evaluating the CE and is
        therefore recommended prior to using it in production. If the method
        is called without arguments orbits will be pruned, for which the ECIs
        are strictly zero. Less restrictive pruning can be achieved by setting
        the `tol` keyword.

        An orbit is only removed if *all* parameters belonging to it are
        selected for removal; orbits for which only some parameters are
        selected are kept unchanged.

        Parameters
        ----------
        indices
            indices of parameters to remove from the cluster expansion.
        tol
            all orbits will be pruned for which the absolute parameter value(s)
            is/are within this tolerance (only used if `indices` is None)

        Raises
        ------
        ValueError
            if index 0 (the zerolet) is among the indices to be removed
        """

        # find parameter indices to be removed (never the zerolet at index 0)
        if indices is None:
            indices = [i for i, param in enumerate(self.parameters)
                       if np.abs(param) <= tol and i > 0]
        df = self.to_dataframe()
        indices = sorted(set(indices))

        if 0 in indices:
            raise ValueError('Orbit index cannot be 0 since the zerolet may not be pruned.')

        # orbit that each selected parameter belongs to; label-based lookup
        # (works for an empty selection and raises KeyError for unknown indices)
        orbit_candidates_for_removal = df['orbit_index'].loc[indices].tolist()
        orbit_index_per_param = df['orbit_index'].tolist()
        safe_to_remove_orbits, safe_to_remove_params = [], []
        for oi in sorted(set(orbit_candidates_for_removal)):
            if oi == -1:
                continue
            orbit_count = orbit_index_per_param.count(oi)
            oi_remove_count = orbit_candidates_for_removal.count(oi)
            # only remove an orbit if all of its parameters were selected
            if orbit_count <= oi_remove_count:
                safe_to_remove_orbits.append(oi)
                safe_to_remove_params += df.index[df['orbit_index'] == oi].tolist()

        # prune cluster space and parameter vector consistently
        self._cluster_space._prune_orbit_list(indices=safe_to_remove_orbits)
        self._parameters = self._parameters[np.setdiff1d(
            np.arange(len(self._parameters)), safe_to_remove_params)]
        assert len(self._parameters) == len(self._cluster_space)

    def write(self, filename: str) -> None:
        """
        Writes ClusterExpansion object to file.

        The file is a tar archive containing the cluster space (member
        ``cluster_space``) and a pickled dictionary with the parameters and
        metadata (member ``items``).

        Parameters
        ----------
        filename
            name of file to which to write
        """
        items = dict()
        items['parameters'] = self.parameters

        # TODO: remove if condition once metadata is firmly established
        if hasattr(self, '_metadata'):
            items['metadata'] = self._metadata

        with tarfile.open(name=filename, mode='w') as tar_file:
            # Write the cluster space into a private temporary directory.
            # Unlike reopening a NamedTemporaryFile by name this also works on
            # Windows, and the directory is always cleaned up.
            with tempfile.TemporaryDirectory() as tmp_dir:
                cs_path = os.path.join(tmp_dir, 'cluster_space')
                self._cluster_space.write(cs_path)
                tar_file.add(cs_path, arcname='cluster_space')

            # write items
            with tempfile.TemporaryFile() as temp_file:
                pickle.dump(items, temp_file)
                temp_file.seek(0)
                tar_info = tar_file.gettarinfo(arcname='items', fileobj=temp_file)
                tar_file.addfile(tar_info, temp_file)

    @staticmethod
    def read(filename: str) -> 'ClusterExpansion':
        """
        Reads ClusterExpansion object from file.

        Parameters
        ----------
        filename
            file from which to read

        Note
        ----
        The file contents are unpickled; only read files from trusted sources.
        """
        with tarfile.open(name=filename, mode='r') as tar_file:
            # extract the cluster space into a private temporary directory so
            # that ClusterSpace.read can open it by name on all platforms
            with tempfile.TemporaryDirectory() as tmp_dir:
                cs_path = os.path.join(tmp_dir, 'cluster_space')
                with open(cs_path, 'wb') as cs_file:
                    cs_file.write(tar_file.extractfile('cluster_space').read())
                cs = ClusterSpace.read(cs_path)
            # NOTE: pickle.load can execute arbitrary code for crafted input
            items = pickle.load(tar_file.extractfile('items'))

        # bypass __init__ so that the stored metadata is not overwritten by
        # fresh default metadata
        ce = ClusterExpansion.__new__(ClusterExpansion)
        ce._cluster_space = cs
        ce._parameters = items['parameters']

        # TODO: remove if condition once metadata is firmly established
        if 'metadata' in items:
            ce._metadata = items['metadata']

        return ce

    def _add_default_metadata(self) -> None:
        """Adds default metadata (creation date, username, hostname and icet
        version) to the metadata dict, overwriting existing entries with the
        same keys."""
        import getpass
        import socket
        from datetime import datetime
        from icet import __version__ as icet_version

        try:
            username = getpass.getuser()
        except (KeyError, ImportError, OSError):
            # getuser() fails if no login environment variable is set and no
            # password-database entry exists (e.g. in some containers)
            username = None

        self._metadata['date_created'] = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
        self._metadata['username'] = username
        self._metadata['hostname'] = socket.gethostname()
        self._metadata['icet_version'] = icet_version