Coverage for icet/core/structure_container.py: 97%

220 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-12-26 04:12 +0000

"""
This module provides the StructureContainer class, which collects atomic
structures together with their properties and cluster vectors, as well as
the FitStructure class that represents a single entry of such a container.
"""

4import tarfile 

5import tempfile 

6 

7from typing import BinaryIO, List, TextIO, Tuple, Union 

8 

9import numpy as np 

10import ase.db 

11from ase import Atoms 

12 

13from icet import ClusterSpace 

14from icet.input_output.logging_tools import logger 

15from pandas import DataFrame 

# Module logger: a child named 'structure_container' of the icet logger
# imported above (this rebinds the imported name).
logger = logger.getChild('structure_container')

17 

18 

class StructureContainer:
    """This class serves as a container for structure objects as well as their
    properties and cluster vectors.

    Parameters
    ----------
    cluster_space
        Cluster space used for evaluating the cluster vectors.

    Example
    -------
    The following snippet illustrates the initialization and usage of
    a :class:`StructureContainer` object. A structure container
    provides convenient means for compiling the data needed to train a
    cluster expansion, i.e., a sensing matrix and target property values::

        >>> from ase.build import bulk
        >>> from icet import ClusterSpace, StructureContainer
        >>> from icet.tools import enumerate_structures
        >>> from random import random

        >>> # create cluster space
        >>> prim = bulk('Au')
        >>> cs = ClusterSpace(prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Pd']])

        >>> # build structure container
        >>> sc = StructureContainer(cs)
        >>> for structure in enumerate_structures(prim, range(5), ['Au', 'Pd']):
        ...     sc.add_structure(structure,
        ...                      properties={'my_random_energy': random()})
        >>> print(sc)

        >>> # fetch sensing matrix and target energies
        >>> A, y = sc.get_fit_data(key='my_random_energy')

    """

    def __init__(self, cluster_space: ClusterSpace):

        if not isinstance(cluster_space, ClusterSpace):
            raise TypeError('cluster_space must be a ClusterSpace object.')

        self._cluster_space = cluster_space
        # FitStructure objects in the order in which they were added
        self._structure_list = []

    def __len__(self) -> int:
        """ Number of structures in the container. """
        return len(self._structure_list)

    def __getitem__(self, ind: int):
        """ Returns the FitStructure object(s) at the given index (or slice). """
        return self._structure_list[ind]

    def get_structure_indices(self, user_tag: str = None) -> List[int]:
        """
        Returns indices of structures with the given user tag. This
        method provides a simple means for filtering structures. The
        :attr:`user_tag` is assigned when adding structures via the
        :func:`add_structure` method.

        Parameters
        ----------
        user_tag
            The indices of structures with this user tag are returned.
            If ``None`` the indices of all structures are returned.

        Returns
        -------
        List of structure indices.
        """
        return [i for i, s in enumerate(self) if user_tag is None or s.user_tag == user_tag]

    def _get_string_representation(self, print_threshold: int = None) -> str:
        """
        String representation of the structure container that provides an
        overview of the structures in the container.

        Parameters
        ----------
        print_threshold
            If the number of structures exceeds this number print dots.

        Returns
        -------
        String representation of the structure container.
        """

        if len(self) == 0:
            return 'Empty StructureContainer'

        # Number of structures to print before cutting and printing dots
        if print_threshold is None or print_threshold >= len(self):
            print_threshold = len(self) + 2

        # format specifiers for fields in table
        def get_format(val):
            if isinstance(val, float):
                return '{:9.4f}'
            else:
                return '{}'

        # table headers
        default_headers = ['index', 'user_tag', 'n_atoms', 'chemical formula']
        property_headers = sorted(set(key for fs in self for key in fs.properties))
        headers = default_headers + property_headers

        # collect the table data; missing properties are shown as empty cells
        str_table = []
        for i, fs in enumerate(self):
            default_data = [i, fs.user_tag, len(fs), fs.structure.get_chemical_formula()]
            property_data = [fs.properties.get(key, '') for key in property_headers]
            str_row = [get_format(d).format(d) for d in default_data+property_data]
            str_table.append(str_row)
        str_table = np.array(str_table)

        # find maximum widths for each column
        widths = []
        for i in range(str_table.shape[1]):
            data_width = max(len(val) for val in str_table[:, i])
            header_width = len(headers[i])
            widths.append(max([data_width, header_width]))

        total_width = sum(widths) + 3 * len(headers)
        row_format = ' | '.join('{:'+str(width)+'}' for width in widths)

        # Make string representation of table
        s = []
        s += ['{s:=^{n}}'.format(s=' Structure Container ', n=total_width)]
        s += ['Total number of structures: {}'.format(len(self))]
        s += [''.center(total_width, '-')]
        s += [row_format.format(*headers)]
        s += [''.center(total_width, '-')]
        for i, fs_data in enumerate(str_table, start=1):
            s += [row_format.format(*fs_data)]
            # once the threshold is reached, print dots followed by the last row
            if i+1 >= print_threshold:
                s += [' ...']
                s += [row_format.format(*str_table[-1])]
                break
        s += [''.center(total_width, '=')]
        s = '\n'.join(s)

        return s

    def __str__(self) -> str:
        """ String representation. """
        return self._get_string_representation(print_threshold=50)

    def _repr_html_(self) -> str:
        """ HTML representation. Used, e.g., in jupyter notebooks. """
        s = ['<h4>Structure Container</h4>']
        s.append(f'<p>Total number of structures: {len(self)}</p>')
        s.append(self.to_dataframe()._repr_html_())
        return ''.join(s)

    def to_dataframe(self) -> DataFrame:
        """Summary of :class:`StructureContainer` object in :class:`DataFrame
        <pandas.DataFrame>` format.

        Each row corresponds to one structure and contains its user tag,
        number of atoms, chemical formula and all of its properties.
        """
        data = []
        for s in self:
            record = dict(
                user_tag=s.user_tag,
                natoms=len(s),
                formula=s.structure.get_chemical_formula('metal'),
            )
            record.update(s.properties)
            data.append(record)
        return DataFrame.from_dict(data)

    def add_structure(self, structure: Atoms, user_tag: str = None,
                      properties: dict = None, allow_duplicate: bool = True,
                      sanity_check: bool = True):
        """
        Adds a structure to the structure container.

        Parameters
        ----------
        structure
            Atomic structure to be added.
        user_tag
            User tag for labeling structure.
        properties
            Scalar properties. If properties are not specified the structure
            object will be checked for an attached ASE calculator object
            with a calculated potential energy. The energy is stored per atom
            under the key ``'energy'``. The dictionary is copied, so later
            changes to the caller's dictionary do not affect the container.
        allow_duplicate
            Whether or not to add the structure if there already exists a
            structure with identical cluster vector.
        sanity_check
            Whether or not to carry out a sanity check before adding the
            structure. This includes checking occupations and volume.

        Raises
        ------
        TypeError
            If :attr:`structure` is not an ASE Atoms object or
            :attr:`user_tag` is not a string.
        ValueError
            If :attr:`allow_duplicate` is ``False`` and a structure with an
            identical cluster vector is already present.
        """

        # structure must have a proper format and label
        if not isinstance(structure, Atoms):
            raise TypeError(f'structure must be an ASE Atoms object not {type(structure)}')

        if user_tag is not None:
            if not isinstance(user_tag, str):
                raise TypeError(f'user_tag must be a string not {type(user_tag)}.')

        if sanity_check:
            self._cluster_space.assert_structure_compatibility(structure)

        if properties is None:
            # check for properties in attached calculator
            properties = {}
            if structure.calc is not None:
                if not structure.calc.calculation_required(structure, ['energy']):
                    energy = structure.get_potential_energy()
                    properties['energy'] = energy / len(structure)
        else:
            # Copy so that reusing one dict for several calls does not make
            # all stored structures share (and overwrite) the same properties.
            properties = dict(properties)

        # check if there exist structures with identical cluster vectors
        structure_copy = structure.copy()
        cv = self._cluster_space.get_cluster_vector(structure_copy)
        if not allow_duplicate:
            for i, fs in enumerate(self):
                if np.allclose(cv, fs.cluster_vector):
                    msg = '{} and {} have identical cluster vectors'.format(
                        user_tag if user_tag is not None else 'Input structure',
                        fs.user_tag if fs.user_tag != 'None' else 'structure')
                    msg += ' at index {}'.format(i)
                    raise ValueError(msg)

        # add structure
        structure = FitStructure(structure_copy, user_tag, cv, properties)
        self._structure_list.append(structure)

    def get_condition_number(self, structure_indices: List[int] = None) -> float:
        """Returns the condition number for the sensing matrix.

        A very large condition number can be a sign of
        multicollinearity. More information can be found
        [here](https://en.wikipedia.org/wiki/Condition_number).

        Parameters
        ----------
        structure_indices
            List of structure indices to include. By default (``None``) the
            method will return all fit data available.

        Returns
        -------
        Condition number of the sensing matrix.
        """
        return np.linalg.cond(self.get_fit_data(structure_indices, key=None)[0])

    def get_fit_data(self, structure_indices: List[int] = None,
                     key: str = 'energy') -> Tuple[np.ndarray, np.ndarray]:
        """
        Returns fit data for all structures. The cluster vectors and
        target properties for all structures are stacked into numpy arrays.

        Parameters
        ----------
        structure_indices
            List of structure indices. By default (``None``) the
            method will return all fit data available.
        key
            Name of property to use. If ``None`` do not include property values.
            This can be useful if only the fit matrix is needed.

        Returns
        -------
        Cluster vectors and target properties for desired structures.
        The second element is ``None`` if :attr:`key` is ``None``.

        Raises
        ------
        IndexError
            If a structure index is out of range.
        KeyError
            If a selected structure lacks the property :attr:`key`.
        """
        if structure_indices is None:
            fit_structures = self._structure_list
        else:
            fit_structures = [self._structure_list[i] for i in structure_indices]

        cv_list = np.array([fs.cluster_vector for fs in fit_structures])
        if key is None:
            prop_list = None
        else:
            prop_list = np.array([fs.properties[key] for fs in fit_structures])

        return cv_list, prop_list

    @property
    def cluster_space(self) -> ClusterSpace:
        """ Cluster space used to calculate the cluster vectors. """
        return self._cluster_space

    @property
    def available_properties(self) -> List[str]:
        """ List of the available properties. """
        return sorted(set([p for fs in self for p in fs.properties.keys()]))

    def write(self, outfile: Union[str, BinaryIO, TextIO]) -> None:
        """
        Writes structure container to a file.

        The output is a tar archive with two members: ``cluster_space``
        (written via the cluster space's own ``write`` method) and
        ``database`` (an ASE database of ``type='db'`` holding the
        structures together with their user tags, properties and cluster
        vectors).

        Parameters
        ----------
        outfile
            Output file name (or path-like object) or a file object opened
            in binary mode.
        """
        import os

        # Stage both archive members in temporary files; they are removed
        # again once the archive has been written (or on error).
        temp_cs_file = tempfile.NamedTemporaryFile(delete=False)
        temp_cs_file.close()
        temp_db_file = tempfile.NamedTemporaryFile(delete=False)
        temp_db_file.close()
        try:
            self.cluster_space.write(temp_cs_file.name)

            # Write fit structures as an ASE db
            if self._structure_list:
                db = ase.db.connect(temp_db_file.name, type='db', append=False)

                for fit_structure in self._structure_list:
                    data_dict = {'user_tag': fit_structure.user_tag,
                                 'properties': fit_structure.properties,
                                 'cluster_vector': fit_structure.cluster_vector}
                    db.write(fit_structure.structure, data=data_dict)

            # tarfile.open only accepts file objects via ``fileobj``
            if isinstance(outfile, (str, bytes, os.PathLike)):
                tar_kwargs = dict(name=outfile)
            else:
                tar_kwargs = dict(fileobj=outfile)
            with tarfile.open(mode='w', **tar_kwargs) as handle:
                handle.add(temp_db_file.name, arcname='database')
                handle.add(temp_cs_file.name, arcname='cluster_space')
        finally:
            for path in (temp_cs_file.name, temp_db_file.name):
                try:
                    os.remove(path)
                except OSError:
                    # best-effort cleanup (e.g. file still locked on Windows)
                    pass

    @staticmethod
    def read(infile: Union[str, BinaryIO, TextIO]):
        """
        Reads :class:`StructureContainer` object from file.

        Parameters
        ----------
        infile
            File from which to read: a file name, a path-like object, or a
            file object with a ``name`` attribute.

        Returns
        -------
        The structure container stored in the file.

        Raises
        ------
        TypeError
            If the file is not a tar file.
        """
        import os

        if isinstance(infile, (str, bytes, os.PathLike)):
            # os.fspath gives the full path; Path.name would be the basename only
            filename = os.fspath(infile)
        else:
            filename = infile.name

        if not tarfile.is_tarfile(filename):
            raise TypeError('{} is not a tar file'.format(filename))

        temp_db_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            with tarfile.open(mode='r', name=filename) as tar_file:
                cs_file = tar_file.extractfile('cluster_space')
                temp_db_file.write(tar_file.extractfile('database').read())
                # close (and thereby flush) before ASE opens the file by name
                temp_db_file.close()
                cluster_space = ClusterSpace.read(cs_file)
            database = ase.db.connect(temp_db_file.name, type='db')

            structure_container = StructureContainer(cluster_space)
            fit_structures = []
            for row in database.select():
                data = row.data
                fit_structure = FitStructure(row.toatoms(),
                                             user_tag=data['user_tag'],
                                             cluster_vector=data['cluster_vector'],
                                             properties=data['properties'])
                fit_structures.append(fit_structure)
            structure_container._structure_list = fit_structures
        finally:
            temp_db_file.close()
            try:
                os.remove(temp_db_file.name)
            except OSError:
                # best-effort cleanup (e.g. file still locked on Windows)
                pass
        return structure_container

384 

385 

class FitStructure:
    """
    This class holds a supercell along with its properties and cluster
    vector.

    Properties can also be accessed as attributes, e.g., ``fs.energy``
    returns ``fs.properties['energy']`` if that key exists.

    Attributes
    ----------
    structure
        Supercell structure.
    user_tag
        Custom user tag.
    cluster_vector
        Cluster vector.
    properties
        Dictionary comprising name and value of properties.
    """

    def __init__(self, structure: Atoms, user_tag: str,
                 cluster_vector: np.ndarray, properties: dict = None):
        self._structure = structure
        self._user_tag = user_tag
        self._cluster_vector = cluster_vector
        # Create a fresh dict per instance; a shared mutable default argument
        # would leak properties between FitStructure objects.
        self.properties = {} if properties is None else properties

    @property
    def cluster_vector(self) -> np.ndarray:
        """ Cluster vector. """
        return self._cluster_vector

    @property
    def structure(self) -> Atoms:
        """ Atomic structure. """
        return self._structure

    @property
    def user_tag(self) -> str:
        """ Structure label (``'None'`` if no tag was given). """
        return str(self._user_tag)

    def __getattr__(self, key):
        """ Accesses properties if possible and returns value. """
        # Look up ``properties`` via __dict__ so that this method does not
        # recurse infinitely on instances created without __init__ (e.g. by
        # copy.deepcopy or unpickling, which probe for special attributes).
        properties = self.__dict__.get('properties', {})
        if key in properties:
            return properties[key]
        return super().__getattribute__(key)

    def __len__(self) -> int:
        """ Number of sites in the structure. """
        return len(self._structure)

    def __str__(self) -> str:
        width = 50
        s = []
        s += ['{s:=^{n}}'.format(s=' Fit Structure ', n=width)]
        s += [' {:22} : {}'.format('user tag', self.user_tag)]
        for k, v in self.properties.items():
            s += [f' {k:22} : {v}']
        # label only the first line of each multi-line section
        t = 'cell metric'
        for row in self.structure.cell[:]:
            s += [f' {t:22} : {row}']
            t = ''
        t = 'sites'
        for site in self.structure:
            s += [f' {t:22} : {site.index} {site.symbol:2} {site.position}']
            t = ''
        s += [''.center(width, '=')]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ HTML representation. Used, e.g., in jupyter notebooks. """
        s = ['<h4>FitStructure</h4>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Property</th><th>Value</th></tr></thead>']
        s += ['<tbody>']
        s += [f'<tr><td style="text-align: left;">user tag</td><td>{self.user_tag}</td></tr>']
        for key, value in sorted(self.properties.items()):
            s += [f'<tr><td style="text-align: left;">{key}</td><td>{value}</td></tr>']
        s += ['</tbody></table>']

        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Cell</th></tr></thead>']
        s += ['<tbody>']
        for row in self.structure.cell[:]:
            s += ['<tr>']
            for c in row:
                s += [f'<td>{c}</td>']
            s += ['</tr>']
        s += ['</tbody></table>']

        df = DataFrame(np.array([self.structure.symbols,
                                 self.structure.positions[:, 0],
                                 self.structure.positions[:, 1],
                                 self.structure.positions[:, 2]]).T,
                       columns=['Species', 'Position x', 'Position y', 'Position z'])
        s.append(df._repr_html_())
        return ''.join(s)