Coverage for icet/core/cluster_expansion.py: 97%

195 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-06 04:14 +0000

1""" 

2This module provides the ClusterExpansion class. 

3""" 

4 

5import os 

6import pandas as pd 

7import numpy as np 

8import pickle 

9import tempfile 

10import tarfile 

11import re 

12 

13from icet import ClusterSpace 

14from icet.core.structure import Structure 

15from typing import List, Union 

16from ase import Atoms 

17 

18 

class ClusterExpansion:
    """Cluster expansions are obtained by combining a cluster space with a set
    of parameters, where the latter is commonly obtained by optimization.
    Instances of this class allow one to predict the property of interest for
    a given structure.

    Note
    ----
    Each element of the parameter vector corresponds to an effective cluster
    interaction (ECI) multiplied by the multiplicity of the underlying orbit.

    Attributes
    ----------
    cluster_space
        Cluster space that was used for constructing the cluster expansion.
    parameters
        Parameter vector.
    metadata
        Metadata dictionary, user-defined metadata to be stored together
        with cluster expansion. Will be pickled when CE is written to file.
        By default contains icet version, username, hostname and date.
        The default entries are always written and replace user-supplied
        entries with the same keys. The dictionary passed in is copied and
        not modified.

    Raises
    ------
    ValueError
        If :attr:`cluster_space` and :attr:`parameters` differ in length.

    Example
    -------
    The following snippet illustrates the initialization and usage of a
    :class:`ClusterExpansion` object. Here, the parameters are taken to be
    a list of ones. Usually, they would be obtained by training with
    respect to a set of reference data::

        >>> from ase.build import bulk
        >>> from icet import ClusterSpace, ClusterExpansion

        >>> # create cluster expansion with fake parameters
        >>> prim = bulk('Au')
        >>> cs = ClusterSpace(prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Pd']])
        >>> parameters = len(cs) * [1.0]
        >>> ce = ClusterExpansion(cs, parameters)

        >>> # make prediction for supercell
        >>> sc = prim.repeat(3)
        >>> for k in [1, 4, 7]:
        ...     sc[k].symbol = 'Pd'
        >>> print(ce.predict(sc))
    """

    def __init__(self, cluster_space: ClusterSpace,
                 parameters: Union[np.ndarray, List[float]],
                 metadata: dict = None) -> None:
        if len(cluster_space) != len(parameters):
            raise ValueError('cluster_space ({}) and parameters ({}) must have'
                             ' the same length'.format(len(cluster_space), len(parameters)))
        self._cluster_space = cluster_space.copy()

        # Always store an array: `prune` relies on fancy indexing, which
        # plain sequences (lists, tuples) do not support.
        self._parameters = np.asarray(parameters)

        # Work on a copy so that adding the default entries does not modify
        # the dictionary owned by the caller.
        self._metadata = dict() if metadata is None else dict(metadata)
        self._add_default_metadata()

    def predict(self, structure: Union[Atoms, Structure]) -> float:
        """
        Returns the property value predicted by the cluster expansion.

        The prediction is the dot product of the cluster vector of the
        structure and the parameter vector.

        Parameters
        ----------
        structure
            Atomic configuration.
        """
        cluster_vector = self._cluster_space.get_cluster_vector(structure)
        return np.dot(cluster_vector, self.parameters)

    def get_cluster_space_copy(self) -> ClusterSpace:
        """ Returns copy of cluster space on which cluster expansion is based. """
        return self._cluster_space.copy()

    def to_dataframe(self) -> pd.DataFrame:
        """Returns a representation of the cluster expansion in the form of a
        DataFrame including effective cluster interactions (ECIs).

        The frame contains the columns provided by the cluster space (except
        for ``index``) followed by ``parameter`` and ``eci``, where the latter
        is the parameter divided by the multiplicity of the orbit."""
        # Build new row dictionaries rather than adding keys to the ones
        # handed out by the cluster space.
        rows = [{**row, 'parameter': param, 'eci': param / row['multiplicity']}
                for row, param in zip(self._cluster_space.as_list, self.parameters)]
        df = pd.DataFrame(rows)
        del df['index']
        return df

    @property
    def chemical_symbols(self) -> List[List[str]]:
        """ Species identified by their chemical symbols (copy). """
        return self._cluster_space.chemical_symbols.copy()

    @property
    def cutoffs(self) -> List[float]:
        """
        Cutoffs for different n-body clusters (copy). The cutoff radius (in
        Ångstroms) defines the largest interatomic distance in a
        cluster.
        """
        return self._cluster_space.cutoffs.copy()

    @property
    def orders(self) -> List[int]:
        """ Orders included in cluster expansion. """
        # The cutoffs start at order 2 (pairs); orders 0 and 1 are always
        # present, hence the offset of two.
        return list(range(len(self._cluster_space.cutoffs) + 2))

    @property
    def parameters(self) -> np.ndarray:
        """ Parameter vector. Each element of the parameter vector corresponds
        to an effective cluster interaction (ECI) multiplied by the
        multiplicity of the respective orbit. """
        return self._parameters

    @property
    def metadata(self) -> dict:
        """ Metadata associated with the cluster expansion. """
        return self._metadata

    @property
    def symprec(self) -> float:
        """ Tolerance imposed when analyzing the symmetry using spglib
        (inherited from the underlying cluster space). """
        return self._cluster_space.symprec

    @property
    def position_tolerance(self) -> float:
        """ Tolerance applied when comparing positions in Cartesian coordinates
        (inherited from the underlying cluster space). """
        return self._cluster_space.position_tolerance

    @property
    def fractional_position_tolerance(self) -> float:
        """ Tolerance applied when comparing positions in fractional coordinates
        (inherited from the underlying cluster space). """
        return self._cluster_space.fractional_position_tolerance

    @property
    def primitive_structure(self) -> Atoms:
        """ Primitive structure on which cluster expansion is based. """
        return self._cluster_space.primitive_structure.copy()

    def __len__(self) -> int:
        """ Number of parameters. """
        return len(self._parameters)

    def _get_string_representation(self, print_threshold: int = None,
                                   print_minimum: int = 10) -> str:
        """ String representation of the cluster expansion.

        Parameters
        ----------
        print_threshold
            If the number of parameters exceeds this number the table body
            is abbreviated; if None the table is always printed in full.
        print_minimum
            Number of rows printed at the beginning and the end of the
            table body when it is abbreviated.
        """
        cluster_space_repr = self._cluster_space._get_string_representation(
            print_threshold, print_minimum).split('\n')
        # rescale width
        par_col_width = max(len('{:9.3g}'.format(max(self._parameters, key=abs))), len('ECI'))
        width = len(cluster_space_repr[0]) + 2 * (len(' | ') + par_col_width)

        s = []
        s += ['{s:=^{n}}'.format(s=' Cluster Expansion ', n=width)]
        # the "key : value" lines of the cluster space header are reused as is
        s += [t for t in cluster_space_repr if ':' in t]

        # additional information about number of nonzero parameters
        df = self.to_dataframe()
        orders = self.orders
        nzp_by_order = [np.count_nonzero(df[df.order == order].eci) for order in orders]
        assert sum(nzp_by_order) == np.count_nonzero(self.parameters)
        s += [' {:38} : {}'.format('total number of nonzero parameters', sum(nzp_by_order))]
        line = ' {:38} :'.format('number of nonzero parameters by order')
        for order, nzp in zip(orders, nzp_by_order):
            line += ' {}= {} '.format(order, nzp)
        s += [line]

        # table header
        s += [''.center(width, '-')]
        t = [t for t in cluster_space_repr if 'index' in t]
        t += ['{s:^{n}}'.format(s='parameter', n=par_col_width)]
        t += ['{s:^{n}}'.format(s='ECI', n=par_col_width)]
        s += [' | '.join(t)]
        s += [''.center(width, '-')]

        # table body
        # The orbit data does not change while the table is assembled, so it
        # is retrieved once rather than for every row.
        orbit_rows = self._cluster_space.as_list
        n_rows = len(self)
        index = 0
        while index < n_rows:
            if (print_threshold is not None and
                    n_rows > print_threshold and
                    index >= print_minimum and
                    index <= n_rows - print_minimum):
                # skip the middle part of a long table
                index = n_rows - print_minimum
                s += [' ...']
            # pick the cluster space row that starts with the current index
            pattern = r'^{:4}'.format(index)
            t = [t for t in cluster_space_repr if re.match(pattern, t)]
            parameter = self._parameters[index]
            t += ['{s:^{n}}'.format(s=f'{parameter:9.3g}', n=par_col_width)]
            eci = parameter / orbit_rows[index]['multiplicity']
            t += ['{s:^{n}}'.format(s=f'{eci:9.3g}', n=par_col_width)]
            s += [' | '.join(t)]
            index += 1
        s += [''.center(width, '=')]

        return '\n'.join(s)

    def __str__(self) -> str:
        """ String representation. """
        return self._get_string_representation(print_threshold=50)

    def _repr_html_(self) -> str:
        """ HTML representation. Used, e.g., in jupyter notebooks. """
        s = ['<h4>Cluster Expansion</h4>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>']
        s += ['<tbody>']
        s += ['<tr><td style="text-align: left;">Space group</td>'
              f'<td>{self._cluster_space.space_group}</td></tr>']
        for sl in self._cluster_space.get_sublattices(
                self.primitive_structure).active_sublattices:
            s += [f'<tr><td style="text-align: left;">Sublattice {sl.symbol}</td>'
                  f'<td>{sl.chemical_symbols}</td></tr>']
        s += ['<tr><td style="text-align: left;">Cutoffs</td>'
              f'<td>{self._cluster_space.cutoffs}</td></tr>']

        df = self.to_dataframe()
        nzp_by_order = [np.count_nonzero(df[df.order == order].eci) for order in self.orders]
        assert sum(nzp_by_order) == np.count_nonzero(self.parameters)
        s += ['<tr><td style="text-align: left;">Total number of parameters (nonzero)</td>'
              f'<td>{len(self)} ({sum(nzp_by_order)})</td></tr>']
        # NOTE(review): the pairing below assumes that number_of_orbits_by_order
        # lists the orders in the same sequence as self.orders -- confirm.
        for (order, npar), nzp in zip(
                self._cluster_space.number_of_orbits_by_order.items(), nzp_by_order):
            s += ['<tr><td style="text-align: left;">'
                  f'Number of parameters of order {order} (nonzero)</td>'
                  f'<td>{npar} ({nzp})</td></tr>']
        s += ['<tr><td style="text-align: left;">fractional_position_tolerance</td>'
              f'<td>{self._cluster_space.fractional_position_tolerance}</td></tr>']
        s += ['<tr><td style="text-align: left;">position_tolerance</td>'
              f'<td>{self._cluster_space.position_tolerance}</td></tr>']
        s += ['<tr><td style="text-align: left;">symprec</td>'
              f'<td>{self._cluster_space.symprec}</td></tr>']

        s += ['</tbody>']
        s += ['</table>']
        return ''.join(s)

    def __repr__(self) -> str:
        """ Representation. """
        s = type(self).__name__ + '('
        s += f'cluster_space={self._cluster_space.__repr__()}'
        s += f', parameters={list(self._parameters).__repr__()}'
        s += ')'
        return s

    def prune(self, indices: List[int] = None, tol: float = 0) -> None:
        """Removes orbits from the cluster expansion, for which the absolute
        values of the corresponding parameters are zero or close to
        zero. This commonly reduces the computational cost for
        evaluating the cluster expansion. It is therefore recommended
        to apply this method prior to using the cluster expansion in
        production. If the method is called without arguments only
        orbits will be pruned, for which the ECIs are strictly zero.
        Less restrictive pruning can be achieved by setting the
        :attr:`tol` keyword.

        An orbit is only removed if all parameters associated with it are
        selected for removal.

        Parameters
        ----------
        indices
            Indices of parameters to remove from the cluster expansion.
            If provided, :attr:`tol` is ignored.
        tol
            All orbits will be pruned for which the absolute parameter value(s)
            is/are within this tolerance.

        Raises
        ------
        ValueError
            If :attr:`indices` contains 0 (the zerolet cannot be pruned).
        """

        # find orbit indices to be removed
        if indices is None:
            # the zerolet (index 0) is never a candidate
            indices = [i for i, param in enumerate(self.parameters)
                       if np.abs(param) <= tol and i > 0]
        df = self.to_dataframe()
        indices = list(set(indices))

        if 0 in indices:
            raise ValueError('Orbit index cannot be 0 since the zerolet may not be pruned.')
        orbit_candidates_for_removal = df.orbit_index[np.array(indices)].tolist()
        all_orbit_indices = df.orbit_index.tolist()
        safe_to_remove_orbits, safe_to_remove_params = [], []
        for oi in set(orbit_candidates_for_removal):
            # NOTE(review): -1 presumably marks the zerolet -- confirm.
            if oi == -1:
                continue
            orbit_count = all_orbit_indices.count(oi)
            oi_remove_count = orbit_candidates_for_removal.count(oi)
            # only remove an orbit if every one of its parameters was selected
            if orbit_count <= oi_remove_count:
                safe_to_remove_orbits.append(oi)
                safe_to_remove_params += df.index[df['orbit_index'] == oi].tolist()

        # prune cluster space
        self._cluster_space._prune_orbit_list(indices=safe_to_remove_orbits)
        self._parameters = self._parameters[np.setdiff1d(
            np.arange(len(self._parameters)), safe_to_remove_params)]
        assert len(self._parameters) == len(self._cluster_space)

    def write(self, filename: str) -> None:
        """
        Writes ClusterExpansion object to file.

        The file is a tar archive with two members: ``cluster_space`` (the
        cluster space as written by its own ``write`` method) and ``items``
        (a pickled dictionary with the parameters and the metadata).

        Parameters
        ----------
        filename
            name of file to which to write
        """
        items = dict()
        items['parameters'] = self.parameters

        # TODO: remove if condition once metadata is firmly established
        if hasattr(self, '_metadata'):
            items['metadata'] = self._metadata

        with tarfile.open(name=filename, mode='w') as tar_file:
            # The cluster space can only be written to a named file, so it
            # goes through a temporary file that is removed again even if
            # writing or archiving fails.
            cs_file = tempfile.NamedTemporaryFile(delete=False)
            cs_file.close()
            try:
                self._cluster_space.write(cs_file.name)
                tar_file.add(cs_file.name, arcname='cluster_space')
            finally:
                os.remove(cs_file.name)

            # write items
            with tempfile.TemporaryFile() as temp_file:
                pickle.dump(items, temp_file)
                temp_file.seek(0)
                tar_info = tar_file.gettarinfo(arcname='items', fileobj=temp_file)
                tar_file.addfile(tar_info, temp_file)

    @staticmethod
    def read(filename: str) -> 'ClusterExpansion':
        """
        Reads :class:`ClusterExpansion` object from file.

        Warning
        -------
        Part of the file content is unpickled. Only read files from
        sources that you trust.

        Parameters
        ----------
        filename
            File from which to read.
        """
        with tarfile.open(name=filename, mode='r') as tar_file:
            # The cluster space can only be read from a named file, so the
            # archive member is copied to a temporary file that is removed
            # again even if reading fails.
            cs_file = tempfile.NamedTemporaryFile(delete=False)
            try:
                cs_file.write(tar_file.extractfile('cluster_space').read())
                cs_file.close()
                cs = ClusterSpace.read(cs_file.name)
            finally:
                cs_file.close()
                os.remove(cs_file.name)
            # NOTE: pickle.load executes arbitrary code for malicious input.
            items = pickle.load(tar_file.extractfile('items'))

        # __init__ is bypassed so that the stored metadata is not overwritten
        # by fresh default metadata.
        ce = ClusterExpansion.__new__(ClusterExpansion)
        ce._cluster_space = cs
        ce._parameters = items['parameters']

        # Files without metadata yield an empty dictionary so that the
        # `metadata` property is always available.
        ce._metadata = items.get('metadata', dict())

        return ce

    def _add_default_metadata(self):
        """ Adds default metadata to metadata dict.

        The keys ``date_created``, ``username``, ``hostname``, and
        ``icet_version`` are set; existing entries with these keys are
        overwritten. """
        import getpass
        import socket
        from datetime import datetime
        from icet import __version__ as icet_version

        self._metadata['date_created'] = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
        self._metadata['username'] = getpass.getuser()
        self._metadata['hostname'] = socket.gethostname()
        self._metadata['icet_version'] = icet_version