Coverage for icet/core/cluster_space.py: 97%

316 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-09-14 04:08 +0000

1""" 

2This module provides the :class:`ClusterSpace` class. 

3""" 

4 

5import os 

6import copy 

7import itertools 

8import pickle 

9import tarfile 

10import tempfile 

11from collections.abc import Iterable 

12from math import log10, floor 

13from typing import Union 

14 

15import numpy as np 

16import spglib 

17 

18from _icet import ClusterSpace as _ClusterSpace 

19from ase import Atoms 

20from ase.io import read as ase_read 

21from ase.io import write as ase_write 

22from icet.core.orbit_list import OrbitList 

23from icet.core.structure import Structure 

24from icet.core.sublattices import Sublattices 

25from icet.tools.geometry import (ase_atoms_to_spglib_cell, 

26 get_occupied_primitive_structure) 

27from pandas import DataFrame 

28 

29 

30class ClusterSpace(_ClusterSpace): 

31 """This class provides functionality for generating and maintaining 

32 cluster spaces. 

33 

34 Note 

35 ---- 

36 In :program:`icet` all :class:`Atoms <ase.Atoms>` objects must have 

37 periodic boundary conditions. When constructing cluster expansions 

38 for surfaces and nanoparticles it is therefore recommended to 

39 surround the structure with vacuum and use periodic boundary 

40 conditions. This can be achieved by using :func:`Atoms.center <ase.Atoms.center>`. 

41 

42 Parameters 

43 ---------- 

44 structure 

45 Atomic structure. 

46 cutoffs 

47 Cutoff radii per order that define the cluster space. 

48 

49 Cutoffs are specified in units of Ångstrom and refer to the 

50 longest distance between two atoms in the cluster. The first 

51 element refers to pairs, the second to triplets, the third 

52 to quadruplets, and so on. :attr:`cutoffs=[7.0, 4.5]` thus implies 

53 that all pairs distanced 7 Å or less will be included, 

54 as well as all triplets among which the longest distance is no 

55 longer than 4.5 Å. 

56 chemical_symbols 

57 List of chemical symbols, each of which must map to an element 

58 of the periodic table. 

59 

60 If a list of chemical symbols is provided, all sites on the 

61 lattice will have the same allowed occupations as the input 

62 list. 

63 

64 If a list of list of chemical symbols is provided then the 

65 outer list must be the same length as the :attr:`structure` object and 

66 :attr:`chemical_symbols[i]` will correspond to the allowed species 

67 on lattice site ``i``. 

68 symprec 

69 Tolerance imposed when analyzing the symmetry using spglib. 

70 position_tolerance 

71 Tolerance applied when comparing positions in Cartesian coordinates. 

72 

73 Examples 

74 -------- 

75 The following snippets illustrate several common situations:: 

76 

77 >>> from ase.build import bulk 

78 >>> from ase.io import read 

79 >>> from icet import ClusterSpace 

80 

81 >>> # AgPd alloy with pairs up to 7.0 A and triplets up to 4.5 A 

82 >>> prim = bulk('Ag') 

83 >>> cs = ClusterSpace(structure=prim, cutoffs=[7.0, 4.5], 

84 ... chemical_symbols=[['Ag', 'Pd']]) 

85 >>> print(cs) 

86 

87 >>> # (Mg,Zn)O alloy on rocksalt lattice with pairs up to 8.0 A 

88 >>> prim = bulk('MgO', crystalstructure='rocksalt', a=6.0) 

89 >>> cs = ClusterSpace(structure=prim, cutoffs=[8.0], 

90 ... chemical_symbols=[['Mg', 'Zn'], ['O']]) 

91 >>> print(cs) 

92 

93 >>> # (Ga,Al)(As,Sb) alloy with pairs, triplets, and quadruplets 

94 >>> prim = bulk('GaAs', crystalstructure='zincblende', a=6.5) 

95 >>> cs = ClusterSpace(structure=prim, cutoffs=[7.0, 6.0, 5.0], 

96 ... chemical_symbols=[['Ga', 'Al'], ['As', 'Sb']]) 

97 >>> print(cs) 

98 

99 >>> # PdCuAu alloy with pairs and triplets 

100 >>> prim = bulk('Pd') 

101 >>> cs = ClusterSpace(structure=prim, cutoffs=[7.0, 5.0], 

102 ... chemical_symbols=[['Au', 'Cu', 'Pd']]) 

103 >>> print(cs) 

104 

105 """ 

106 

def __init__(self,
             structure: Atoms,
             cutoffs: list[float],
             chemical_symbols: Union[list[str], list[list[str]]],
             symprec: float = 1e-5,
             position_tolerance: float = None) -> None:
    """Validates the input, sets up the orbit list on the occupied
    primitive structure, and initializes the ``_ClusterSpace`` base class.

    Raises
    ------
    TypeError
        If :attr:`structure` is not an :class:`Atoms <ase.Atoms>` object.
    ValueError
        If :attr:`structure` is not periodic in all directions, or if
        :attr:`symprec` or :attr:`position_tolerance` is not positive.
    """
    # sanity checks on the input
    if not isinstance(structure, Atoms):
        raise TypeError('Input configuration must be an ASE Atoms object'
                        f', not type {type(structure)}.')
    if not all(structure.pbc):
        raise ValueError('Input structure must be periodic.')
    if symprec <= 0:
        raise ValueError('symprec must be a positive number.')

    # keep private copies of the input so that the cluster space can be
    # rebuilt later (see write/read/copy)
    self._config = {'symprec': symprec}
    self._cutoffs = cutoffs.copy()
    self._input_structure = structure.copy()
    self._input_chemical_symbols = copy.deepcopy(chemical_symbols)
    symbols_per_site = self._get_chemical_symbols()

    self._pruning_history: list[tuple] = []

    # primitive structure that only contains the occupied sites
    primitive, primitive_symbols = get_occupied_primitive_structure(
        self._input_structure, symbols_per_site, symprec=self.symprec)
    self._primitive_chemical_symbols = primitive_symbols
    assert len(primitive) == len(primitive_symbols)

    # position tolerance defaults to symprec
    if position_tolerance is not None:
        if position_tolerance <= 0:
            raise ValueError('position_tolerance must be a positive number')
        self._config['position_tolerance'] = position_tolerance
    else:
        self._config['position_tolerance'] = symprec

    # The fractional tolerance is the Cartesian tolerance scaled by the
    # cube root of the cell volume, capped at a fifth of the Cartesian
    # tolerance, and rounded to one significant digit.
    box_length = abs(np.linalg.det(primitive.cell)) ** (1 / 3)
    frac_tol = min(self.position_tolerance / box_length,
                   self._config['position_tolerance'] / 5)
    leading_digit = -int(floor(log10(abs(frac_tol))))
    self._config['fractional_position_tolerance'] = round(frac_tol, leading_digit)

    # orbit list; orbits with inactive sites are removed
    self._orbit_list = OrbitList(
        structure=primitive,
        cutoffs=self._cutoffs,
        chemical_symbols=self._primitive_chemical_symbols,
        symprec=self.symprec,
        position_tolerance=self.position_tolerance,
        fractional_position_tolerance=self.fractional_position_tolerance)
    self._orbit_list.remove_orbits_with_inactive_sites()

    # call (base) C++ constructor
    _ClusterSpace.__init__(
        self,
        orbit_list=self._orbit_list,
        position_tolerance=self.position_tolerance,
        fractional_position_tolerance=self.fractional_position_tolerance)

163 

def _get_chemical_symbols(self):
    """ Returns the allowed chemical symbols for each site of the input
    structure as a list of lists, based on the input structure and the
    input chemical symbols. Carries out multiple sanity checks.

    Returns
    -------
    list[list[str]]
        One independent list of allowed symbols per site of the input
        structure; modifying the result does not affect the stored input.

    Raises
    ------
    TypeError
        If the input chemical symbols are a bare string or a mixture of
        strings and iterables.
    ValueError
        If the number of per-site entries does not match the number of
        sites, if a site lists the same symbol more than once, or if no
        site allows more than one species.
    """
    input_symbols = self._input_chemical_symbols
    number_of_sites = len(self._input_structure)

    # A bare string such as 'AgPd' would otherwise pass the all-str check
    # below (iterating a str yields str) and be treated as a list of
    # single characters.
    if isinstance(input_symbols, str):
        raise TypeError('chemical_symbols must be list[str] or list[list[str]], not {}'.format(
            type(input_symbols)))

    # set up chemical symbols as list[list[str]]
    if all(isinstance(i, str) for i in input_symbols):
        # one independent list per site (avoids aliasing between sites and
        # with the stored input)
        chemical_symbols = [list(input_symbols) for _ in range(number_of_sites)]
    # also accept tuples and other iterables but not, e.g., list[list, str]
    # (need to check for str explicitly here because str is an Iterable)
    elif not all(isinstance(i, Iterable) and not isinstance(i, str)
                 for i in input_symbols):
        raise TypeError('chemical_symbols must be list[str] or list[list[str]], not {}'.format(
            type(input_symbols)))
    elif len(input_symbols) != number_of_sites:
        msg = 'chemical_symbols must have same length as structure. '
        msg += 'len(chemical_symbols) = {}, len(structure)= {}'.format(
            len(input_symbols), number_of_sites)
        raise ValueError(msg)
    else:
        chemical_symbols = copy.deepcopy(input_symbols)

    for i, symbols in enumerate(chemical_symbols):
        if len(symbols) != len(set(symbols)):
            raise ValueError(
                'Found duplicates of allowed chemical symbols on site {}.'
                ' allowed species on site {}= {}'.format(i, i, symbols))

    # at least one site must allow more than one species
    if not any(len(symbols) > 1 for symbols in chemical_symbols):
        raise ValueError('No active sites found')

    return chemical_symbols

195 

def _get_chemical_symbol_representation(self):
    """Returns a human-readable string that lists, for each active
    sublattice of the primitive structure, its chemical symbols and its
    sublattice symbol.
    """
    active = self.get_sublattices(self.primitive_structure).active_sublattices
    return ', '.join(
        '{} (sublattice {})'.format(list(sl.chemical_symbols), sl.symbol)
        for sl in active)

207 

def _get_string_representation(self,
                               print_threshold: int = None,
                               print_minimum: int = 10) -> str:
    """
    Assembles a multi-line overview of the cluster space: a header with
    general information followed by a table with one row per entry of
    :attr:`as_list` (order, radius, multiplicity etc).

    Parameters
    ----------
    print_threshold
        If the number of parameters exceeds this number the middle of the
        table is replaced by dots. No abbreviation if ``None``.
    print_minimum
        Number of rows printed from the top and the bottom of the table
        if `print_threshold` is exceeded.

    Returns
    -------
    multi-line string
        String representation of the cluster space.
    """
    column_formats = {'order': '{:2}',
                      'radius': '{:8.4f}',
                      'multiplicity': '{:4}',
                      'index': '{:4}',
                      'orbit_index': '{:4}',
                      'multicomponent_vector': '{:}',
                      'sublattices': '{:}'}

    def repr_orbit(orbit, header=False):
        # Each cell is centered in a column wide enough for both the
        # column name and the formatted value.
        cells = []
        for name, value in orbit.items():
            if name == 'sublattices':
                value = '-'.join(value)
            text = column_formats[name].format(value)
            column_width = max(len(name), len(text))
            shown = name if header else text
            cells.append('{s:^{n}}'.format(s=shown, n=column_width))
        return ' | '.join(cells)

    # basic information
    # (the last entry serves as prototype for the line length)
    prototype = self.as_list[-1]
    width = len(repr_orbit(prototype))
    field = ' {:38} : {}'
    lines = ['{s:=^{n}}'.format(s=' Cluster Space ', n=width)]
    lines.append(field.format('space group', self.space_group))
    lines.append(field.format('chemical species',
                              self._get_chemical_symbol_representation()))
    lines.append(field.format('cutoffs',
                              ' '.join('{:.4f}'.format(c) for c in self.cutoffs)))
    lines.append(field.format('total number of parameters', len(self)))
    per_order = ' '.join('{}= {}'.format(k, c)
                         for k, c in self.number_of_orbits_by_order.items())
    lines.append(field.format('number of parameters by order', per_order))
    for key, value in sorted(self._config.items()):
        lines.append(field.format(key, value))

    # table header
    rule = ''.center(width, '-')
    lines.extend([rule, repr_orbit(prototype, header=True), rule])

    # table body
    rows = self.as_list
    row = 0
    while row < len(rows):
        abbreviate = (print_threshold is not None
                      and len(self) > print_threshold
                      and print_minimum <= row <= len(self) - print_minimum)
        if abbreviate:
            # jump to the tail of the table and mark the gap
            row = len(self) - print_minimum
            lines.append(' ...')
        lines.append(repr_orbit(rows[row]))
        row += 1
    lines.append(''.center(width, '='))

    return '\n'.join(lines)

288 

def __str__(self) -> str:
    """ String representation; uses a print threshold of 50. """
    overview = self._get_string_representation(print_threshold=50)
    return overview

292 

def _repr_html_(self) -> str:
    """ HTML representation (a two-column table of fields and values).
    Used, e.g., in jupyter notebooks. """
    html = [
        '<h4>Cluster Space</h4>',
        '<table border="1" class="dataframe">',
        '<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>',
        '<tbody>',
        f'<tr><td style="text-align: left;">Space group</td><td>{self.space_group}</td></tr>',
    ]
    # one row per active sublattice of the primitive structure
    html.extend(
        f'<tr><td style="text-align: left;">Sublattice {sl.symbol}</td>'
        f'<td>{sl.chemical_symbols}</td></tr>'
        for sl in self.get_sublattices(self.primitive_structure).active_sublattices)
    html.append(f'<tr><td style="text-align: left;">Cutoffs</td><td>{self.cutoffs}</td></tr>')
    html.append('<tr><td style="text-align: left;">Total number of parameters</td>'
                f'<td>{len(self)}</td></tr>')
    html.extend(
        f'<tr><td style="text-align: left;">Number of parameters of order {k}</td>'
        f'<td>{n}</td></tr>'
        for k, n in self.number_of_orbits_by_order.items())
    # configuration entries (tolerances) in alphabetical order
    html.extend(
        f'<tr><td style="text-align: left;">{key}</td><td>{value}</td></tr>'
        for key, value in sorted(self._config.items()))
    html.append('</tbody>')
    html.append('</table>')
    return ''.join(html)

314 

def __repr__(self) -> str:
    """ Representation in the form of a constructor call listing the
    primitive structure, cutoffs, input chemical symbols, and position
    tolerance. """
    arguments = [
        f'structure={self.primitive_structure.__repr__()}',
        f'cutoffs={self._cutoffs.__repr__()}',
        f'chemical_symbols={self._input_chemical_symbols.__repr__()}',
        f'position_tolerance={self._config["position_tolerance"]}',
    ]
    return type(self).__name__ + '(' + ', '.join(arguments) + ')'

324 

def __getitem__(self, ind: int):
    """ Returns the entry of :attr:`as_list` at position :attr:`ind`. """
    entries = self.as_list
    return entries[ind]

327 

@property
def symprec(self) -> float:
    """ Tolerance imposed when analyzing the symmetry using spglib
    (stored in the configuration under ``'symprec'``). """
    config = self._config
    return config['symprec']

332 

@property
def position_tolerance(self) -> float:
    """ Tolerance applied when comparing positions in Cartesian coordinates
    (stored in the configuration under ``'position_tolerance'``). """
    config = self._config
    return config['position_tolerance']

337 

@property
def fractional_position_tolerance(self) -> float:
    """ Tolerance applied when comparing positions in fractional coordinates
    (stored in the configuration under ``'fractional_position_tolerance'``). """
    config = self._config
    return config['fractional_position_tolerance']

342 

@property
def space_group(self) -> str:
    """ Space group of the primitive structure in international notation (via spglib). """
    cell = ase_atoms_to_spglib_cell(self.primitive_structure)
    tolerance = self._config['symprec']
    return spglib.get_spacegroup(cell, symprec=tolerance)

348 

@property
def as_list(self) -> list[dict]:
    """Representation of cluster space as list of dictionaries with
    information regarding order, radius, multiplicity etc.

    The first entry represents the zerolet. It is followed by one entry
    per cluster vector element of each orbit in the orbit list.
    """
    entries = [dict(index=0,
                    order=0,
                    radius=0,
                    multiplicity=1,
                    orbit_index=-1,
                    multicomponent_vector='.',
                    sublattices='.')]

    sublattices = self.get_sublattices(self.primitive_structure)
    for orbit_index in range(len(self.orbit_list)):
        orbit = self.orbit_list.get_orbit(orbit_index)
        cluster = orbit.representative_cluster

        # sublattice symbol of each site of the representative cluster
        symbols = []
        for site in cluster.lattice_sites:
            sublattice_index = sublattices.get_sublattice_index_from_site_index(site.index)
            symbols.append(sublattices[sublattice_index].symbol)

        for element in orbit.cluster_vector_elements:
            # the running index equals the number of entries collected so far
            entries.append(dict(index=len(entries),
                                order=cluster.order,
                                radius=cluster.radius,
                                multiplicity=element['multiplicity'],
                                orbit_index=orbit_index,
                                multicomponent_vector=element['multicomponent_vector'],
                                sublattices=symbols))
    return entries

386 

def to_dataframe(self) -> DataFrame:
    """ Returns a representation of the cluster space as a DataFrame.
    The columns correspond to the keys of the entries of :attr:`as_list`
    with the exception of ``index``, which is dropped. """
    frame = DataFrame.from_dict(self.as_list)
    frame.pop('index')
    return frame

392 

@property
def number_of_orbits_by_order(self) -> dict:
    """ Number of entries of :attr:`as_list` by order in the form of a
    dictionary, where keys and values represent order and number of
    entries, respectively. The keys are sorted in ascending order.
    """
    orders = sorted(entry['order'] for entry in self.as_list)
    # after sorting equal orders are adjacent and can be counted per group
    return {order: len(list(group)) for order, group in itertools.groupby(orders)}

404 

def get_cluster_vector(self, structure: Atoms) -> np.ndarray:
    """
    Returns the cluster vector for a structure.

    Parameters
    ----------
    structure
        Atomic configuration.

    Raises
    ------
    TypeError
        If :attr:`structure` is not an :class:`Atoms <ase.Atoms>` object.
    Exception
        If the calculation fails. If the structure fails the check in
        :func:`assert_structure_compatibility` the error from that check
        is raised; otherwise the original exception is re-raised with its
        type and traceback intact.
    """
    if not isinstance(structure, Atoms):
        raise TypeError('Input structure must be an ASE Atoms object')

    try:
        return _ClusterSpace.get_cluster_vector(
            self,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=self.fractional_position_tolerance)
    except Exception:
        # Run the compatibility check first so that its error message is
        # raised instead if the structure is the cause of the failure.
        self.assert_structure_compatibility(structure)
        # Otherwise propagate the original exception unchanged rather
        # than wrapping it in a generic Exception.
        raise

426 

def get_coordinates_of_representative_cluster(self, orbit_index: int) -> list[tuple[float]]:
    """
    Returns the positions of the sites in the representative cluster of the selected orbit.

    Parameters
    ----------
    orbit_index
        Index of the orbit for which to return the positions of the sites.

    Raises
    ------
    ValueError
        If :attr:`orbit_index` is not a valid index into the current orbit list.
    """
    valid_indices = range(len(self._orbit_list))
    if orbit_index not in valid_indices:
        raise ValueError('The input orbit index is not in the list of possible values.')
    orbit = self._orbit_list.get_orbit(orbit_index)
    return orbit.representative_cluster.positions

440 

def _remove_orbits(self, indices: list[int]) -> None:
    """
    Removes orbits from the internal orbit list.

    Parameters
    ----------
    indices
        Indices to all orbits to be removed.
    """
    count_before = len(self._orbit_list)

    # Removing an orbit shifts the indices of all subsequent orbits;
    # going from the largest to the smallest index keeps the remaining
    # indices valid.
    for orbit_index in sorted(indices, reverse=True):
        self._orbit_list.remove_orbit(orbit_index)

    count_removed = count_before - len(self._orbit_list)
    assert count_removed == len(indices)

459 

def prune_orbit_list(self, indices: list[int]) -> None:
    """
    Prunes the internal orbit list and maintains the history.

    Parameters
    ----------
    indices
        Indices to all orbits to be removed.
    """
    # Record a snapshot rather than the caller's list. The history is
    # replayed when the cluster space is copied or read from file, so it
    # must not change if the caller later modifies the list.
    removed = list(indices)
    self._remove_orbits(removed)
    self._pruning_history.append(('prune', removed))

471 

472 @property 

473 def primitive_structure(self) -> Atoms: 

474 """ Primitive structure on which cluster space is based. """ 

475 structure = self._get_primitive_structure().to_atoms() 

476 # Decorate with the "real" symbols (instead of H, He, Li etc) 

477 for atom, symbols in zip(structure, self._primitive_chemical_symbols): 

478 atom.symbol = min(symbols) 

479 return structure 

480 

@property
def chemical_symbols(self) -> list[list[str]]:
    """ Species identified by their chemical symbols, one list of allowed
    symbols per site of the primitive structure.

    A deep copy is returned, so that modifying the result (including the
    per-site lists) does not affect the cluster space.
    """
    return copy.deepcopy(self._primitive_chemical_symbols)

485 

@property
def cutoffs(self) -> list[float]:
    """
    Cutoffs for different n-body clusters. The cutoff radius (in
    Ångstroms) defines the largest interatomic distance in a
    cluster.

    A copy is returned, so that modifying the result does not alter the
    cutoffs stored with the cluster space.
    """
    return self._cutoffs.copy()

494 

@property
def orbit_list(self):
    """ Orbit list that defines the clusters in the cluster space.
    This is the internal object, not a copy. """
    orbits = self._orbit_list
    return orbits

499 

def get_possible_orbit_occupations(self, orbit_index: int) -> list[list[str]]:
    """ Returns possible occupations of the orbit, i.e., all combinations
    of the species allowed on the sites of its representative cluster.
    Each occupation is returned as a tuple of chemical symbols.

    Parameters
    ----------
    orbit_index
        Index of orbit of interest.
    """
    cluster = self.orbit_list.orbits[orbit_index].representative_cluster
    symbols_per_site = self.chemical_symbols
    allowed_species = [symbols_per_site[site.index] for site in cluster.lattice_sites]
    occupations = itertools.product(*allowed_species)
    return list(occupations)

512 

def get_sublattices(self, structure: Atoms) -> Sublattices:
    """ Returns the sublattices of the input structure, set up from the
    allowed chemical symbols and the primitive structure of this cluster
    space.

    Parameters
    ----------
    structure
        Atomic structure the sublattices are based on.
    """
    return Sublattices(
        self.chemical_symbols,
        self.primitive_structure,
        structure,
        fractional_position_tolerance=self.fractional_position_tolerance)

526 

527 def assert_structure_compatibility(self, structure: Atoms, vol_tol: float = 1e-5) -> None: 

528 """ Raises error if structure is not compatible with this cluster space. 

529 

530 Parameters 

531 ---------- 

532 structure 

533 Structure to check for compatibility with cluster space. 

534 vol_tol 

535 Tolerance imposed when comparing volumes. 

536 """ 

537 # check volume 

538 vol1 = self.primitive_structure.get_volume() / len(self.primitive_structure) 

539 vol2 = structure.get_volume() / len(structure) 

540 if abs(vol1 - vol2) > vol_tol: 

541 raise ValueError(f'Volume per atom of structure ({vol1}) does not match the volume of' 

542 f' the primitive structure ({vol2}; vol_tol= {vol_tol}).') 

543 

544 # check occupations 

545 sublattices = self.get_sublattices(structure) 

546 sublattices.assert_occupation_is_allowed(structure.get_chemical_symbols()) 

547 

548 # check pbc 

549 if not all(structure.pbc): 

550 raise ValueError('Input structure must be periodic.') 

551 

def merge_orbits(self,
                 equivalent_orbits: dict[int, list[int]],
                 ignore_permutations: bool = False) -> None:
    """ Combines several orbits into one. This allows one to make custom
    cluster spaces by manually declaring the clusters in two or more
    orbits to be equivalent. This is a powerful approach for simplifying
    the cluster spaces of low-dimensional structures such as
    surfaces or nanoparticles.

    The procedure works in principle for any number of components. Note,
    however, that in the case of more than two components the outcome of
    the merging procedure inherits the treatment of the multi-component
    vectors of the orbit chosen as the representative one.

    All requested merges are validated before any orbit is modified. If
    an error is raised the cluster space and its merge/prune history are
    therefore left unchanged.

    Parameters
    ----------
    equivalent_orbits
        The keys of this dictionary denote the indices of the orbit into
        which to merge. The values are the indices of the orbits that are
        supposed to be merged into the orbit denoted by the key.
    ignore_permutations
        If ``True`` orbits will be merged even if their multi-component
        vectors and/or site permutations differ. While the object will
        still be functional, the cluster space may not be properly spanned
        by the resulting cluster vectors.

    Raises
    ------
    ValueError
        If an orbit is to be merged with itself or into more than one
        orbit, if the orders of two orbits differ, or (unless
        :attr:`ignore_permutations` is ``True``) if their cluster vector
        elements, site permutations, or multi-component vectors differ.

    Note
    ----
    The orbit index should not be confused with the index shown when
    printing the cluster space.

    Examples
    --------
    The following snippet illustrates the use of this method to create a
    cluster space for a (111) FCC surface, in which only the singlets for
    the first and second layer are distinct as well as the in-plane pair
    interaction in the topmost layer. All other singlets and pairs are
    respectively merged into one orbit. After merging there are only 3
    singlets and 2 pairs left with correspondingly higher multiplicities.

        >>> from icet import ClusterSpace
        >>> from ase.build import fcc111
        >>>
        >>> # Create primitive surface unit cell
        >>> structure = fcc111('Au', size=(1, 1, 8),
        ...                    a=4.1, vacuum=10, periodic=True)
        >>>
        >>> # Set up initial cluster space
        >>> cs = ClusterSpace(structure=structure,
        ...                   cutoffs=[3.8], chemical_symbols=['Au', 'Ag'])
        >>>
        >>> # At this point, one can inspect the orbits in the cluster space
        >>> # by printing the ClusterSpace object and accessing the individual
        >>> # orbits. There will be 4 singlets and 8 pairs.
        >>>
        >>> # Merge singlets for the third and fourth layers as well as all
        >>> # pairs except for the one corresponding to the in-plane
        >>> # interaction in the topmost surface layer.
        >>> cs.merge_orbits({2: [3], 4: [6, 7, 8, 9, 10, 11]})
    """

    # First pass: validate all requested merges without modifying anything.
    pairs_to_merge = []
    orbits_to_delete = []
    for k1, orbit_indices in equivalent_orbits.items():
        orbit1 = self.orbit_list.get_orbit(k1)

        for k2 in orbit_indices:

            # sanity checks
            if k1 == k2:
                raise ValueError(f'Cannot merge orbit {k1} with itself.')
            if k2 in orbits_to_delete:
                raise ValueError(f'Orbit {k2} cannot be merged into orbit {k1}'
                                 ' since it was already merged with another orbit.')
            orbit2 = self.orbit_list.get_orbit(k2)
            if orbit1.order != orbit2.order:
                raise ValueError(f'The order of orbit {k1} ({orbit1.order}) does not'
                                 f' match the order of orbit {k2} ({orbit2.order}).')

            if not ignore_permutations:
                elements1 = orbit1.cluster_vector_elements
                elements2 = orbit2.cluster_vector_elements

                # The pairwise comparisons below use zip, which would
                # silently skip surplus elements of the longer list.
                if len(elements1) != len(elements2):
                    raise ValueError(f'Orbit {k1} and orbit {k2} have different '
                                     'numbers of cluster vector elements.')

                # compare site permutations
                for el1, el2 in zip(elements1, elements2):
                    vec_group1 = el1['site_permutations']
                    vec_group2 = el2['site_permutations']
                    if len(vec_group1) != len(vec_group2) or \
                            not np.allclose(np.array(vec_group1), np.array(vec_group2)):
                        raise ValueError(f'Orbit {k1} and orbit {k2} have different '
                                         'site permutations.')

                # compare multi-component vectors (maybe this is redundant because
                # site permutations always differ if multi-component vectors differ?)
                if not all(np.allclose(el1['multicomponent_vector'],
                                       el2['multicomponent_vector'])
                           for el1, el2 in zip(elements1, elements2)):
                    raise ValueError(f'Orbit {k1} and orbit {k2} have different '
                                     'multi-component vectors.')

            pairs_to_merge.append((k1, k2))
            orbits_to_delete.append(k2)

    # Second pass: merge, then remove the orbits that were merged.
    for k1, k2 in pairs_to_merge:
        self._merge_orbit(k1, k2)
    self._remove_orbits(orbits_to_delete)

    # update merge/prune history (only after success; a snapshot is stored
    # so that later changes to the caller's dictionary do not alter it)
    self._pruning_history.append(('merge', copy.deepcopy(equivalent_orbits)))

660 

661 def is_supercell_self_interacting(self, structure: Atoms) -> bool: 

662 """ 

663 Checks whether a structure has self-interactions via periodic 

664 boundary conditions. 

665 Returns ``True`` if the structure contains self-interactions via periodic 

666 boundary conditions, otherwise ``False``. 

667 

668 Parameters 

669 ---------- 

670 structure 

671 Structure to be tested. 

672 """ 

673 ol = self.orbit_list.get_supercell_orbit_list( 

674 structure=structure, 

675 fractional_position_tolerance=self.fractional_position_tolerance) 

676 orbit_indices = set() 

677 for orbit in ol.orbits: 

678 for cluster in orbit.clusters: 

679 indices = tuple(sorted([site.index for site in cluster.lattice_sites])) 

680 if indices in orbit_indices: 

681 return True 

682 else: 

683 orbit_indices.add(indices) 

684 return False 

685 

def get_multiplicities(self) -> list[int]:
    """
    Returns the multiplicity of each entry of :attr:`as_list` as a list.
    """
    return [entry['multiplicity'] for entry in self.as_list]

691 

def write(self, filename: str) -> None:
    """
    Saves cluster space to a file.

    The file is a tar archive with two members: ``items`` (pickled
    dictionary with cutoffs, input chemical symbols, merge/prune history,
    and tolerances) and ``atoms`` (the input structure in ASE json format).

    Parameters
    ---------
    filename
        Name of file to which to write.
    """
    items = dict(cutoffs=self._cutoffs,
                 chemical_symbols=self._input_chemical_symbols,
                 pruning_history=self._pruning_history,
                 symprec=self.symprec,
                 position_tolerance=self.position_tolerance)

    with tarfile.open(name=filename, mode='w') as tar_file:

        # write items
        with tempfile.TemporaryFile() as items_file:
            pickle.dump(items, items_file)
            items_file.seek(0)
            tar_info = tar_file.gettarinfo(arcname='items', fileobj=items_file)
            tar_file.addfile(tar_info, items_file)

        # write structure
        # (the structure is written via a named file that is closed first
        # and removed afterwards, also in the case of an error)
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        temp_file.close()
        try:
            ase_write(temp_file.name, self._input_structure, format='json')
            with open(temp_file.name, 'rb') as atoms_file:
                tar_info = tar_file.gettarinfo(arcname='atoms', fileobj=atoms_file)
                tar_file.addfile(tar_info, atoms_file)
        finally:
            os.remove(temp_file.name)

725 

@staticmethod
def read(filename: str):
    """
    Reads cluster space from file and returns :attr:`ClusterSpace` object.

    The cluster space is rebuilt from the stored input (structure,
    cutoffs, chemical symbols, tolerances), after which the stored
    merge/prune history is replayed.

    Parameters
    ---------
    filename
        Name of file (``str`` or path-like object) from which to read
        cluster space. Any other object is treated as a file object.
    """
    if isinstance(filename, (str, os.PathLike)):
        tar_file = tarfile.open(mode='r', name=filename)
    else:
        tar_file = tarfile.open(mode='r', fileobj=filename)

    # the context manager closes the archive also in the case of an error
    with tar_file:
        # read items
        # NOTE: the items are unpickled; unpickling can execute arbitrary
        # code, so only files from trusted sources should be read.
        items = pickle.load(tar_file.extractfile('items'))

        # read structure
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            temp_file.write(tar_file.extractfile('atoms').read())
            temp_file.close()
            structure = ase_read(temp_file.name, format='json')
        finally:
            temp_file.close()  # no-op if already closed
            os.remove(temp_file.name)

    # ensure backward compatibility
    if 'symprec' not in items:  # pragma: no cover
        items['symprec'] = 1e-5
    if 'position_tolerance' not in items:  # pragma: no cover
        items['position_tolerance'] = items['symprec']

    cs = ClusterSpace(structure=structure,
                      cutoffs=items['cutoffs'],
                      chemical_symbols=items['chemical_symbols'],
                      symprec=items['symprec'],
                      position_tolerance=items['position_tolerance'])
    if len(items['pruning_history']) > 0:
        if isinstance(items['pruning_history'][0], tuple):
            for key, value in items['pruning_history']:
                if key == 'prune':
                    cs.prune_orbit_list(value)
                elif key == 'merge':
                    # It is safe to ignore permutations here because otherwise
                    # the orbits could not have been merged in the first place.
                    cs.merge_orbits(value, ignore_permutations=True)
        else:  # for backwards compatibility
            # (history entries that are not tuples are treated as lists
            # of indices to prune)
            for value in items['pruning_history']:
                cs.prune_orbit_list(value)

    return cs

778 

def copy(self):
    """ Returns copy of :class:`ClusterSpace` instance. The copy is built
    from the stored input, after which the merge/prune history is
    replayed in order. """
    duplicate = ClusterSpace(structure=self._input_structure,
                             cutoffs=self.cutoffs,
                             chemical_symbols=self._input_chemical_symbols,
                             symprec=self.symprec,
                             position_tolerance=self.position_tolerance)

    for action, argument in self._pruning_history:
        if action == 'merge':
            # It is safe to ignore permutations here because otherwise
            # the orbits could not have been merged in the first place.
            duplicate.merge_orbits(argument, ignore_permutations=True)
        elif action == 'prune':
            duplicate.prune_orbit_list(argument)
    return duplicate