Coverage for icet/core/structure_container.py: 97%
220 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
1"""
2This module provides the StructureContainer class.
3"""
4import tarfile
5import tempfile
7from typing import BinaryIO, List, TextIO, Tuple, Union
9import numpy as np
10import ase.db
11from ase import Atoms
13from icet import ClusterSpace
14from icet.input_output.logging_tools import logger
15from pandas import DataFrame
# Module-specific child of the logger imported from icet.input_output.logging_tools.
logger = logger.getChild(suffix='structure_container')
class StructureContainer:
    """This class serves as a container for structure objects as well as their
    properties and cluster vectors.

    Parameters
    ----------
    cluster_space
        Cluster space used for evaluating the cluster vectors.

    Example
    -------
    The following snippet illustrates the initialization and usage of
    a :class:`StructureContainer` object. A structure container
    provides convenient means for compiling the data needed to train a
    cluster expansion, i.e., a sensing matrix and target property values::

        >>> from ase.build import bulk
        >>> from icet import ClusterSpace, StructureContainer
        >>> from icet.tools import enumerate_structures
        >>> from random import random

        >>> # create cluster space
        >>> prim = bulk('Au')
        >>> cs = ClusterSpace(prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Pd']])

        >>> # build structure container
        >>> sc = StructureContainer(cs)
        >>> for structure in enumerate_structures(prim, range(5), ['Au', 'Pd']):
        ...     sc.add_structure(structure,
        ...                      properties={'my_random_energy': random()})
        >>> print(sc)

        >>> # fetch sensing matrix and target energies
        >>> A, y = sc.get_fit_data(key='my_random_energy')

    """

    def __init__(self, cluster_space: ClusterSpace):

        if not isinstance(cluster_space, ClusterSpace):
            raise TypeError('cluster_space must be a ClusterSpace object.')

        self._cluster_space = cluster_space
        # FitStructure objects, in the order in which they were added
        self._structure_list = []

    def __len__(self) -> int:
        """ Number of structures in the container. """
        return len(self._structure_list)

    def __getitem__(self, ind: int):
        """ Returns the entry (or entries, for a slice) of the internal
        structure list selected by ``ind``.

        Since no ``__iter__`` is defined, this method is also what makes
        ``for fs in container`` work.
        """
        return self._structure_list[ind]

    def get_structure_indices(self, user_tag: str = None) -> List[int]:
        """
        Returns indices of structures with the given user tag. This
        method provides a simple means for filtering structures. The
        :attr:`user_tag` is assigned when adding structures via the
        :func:`add_structure` method.

        Parameters
        ----------
        user_tag
            The indices of structures with this user tag are returned.
            If ``None``, the indices of all structures are returned.

        Returns
        -------
        List of structure indices.
        """
        return [i for i, s in enumerate(self) if user_tag is None or s.user_tag == user_tag]

    def _get_string_representation(self, print_threshold: int = None) -> str:
        """
        String representation of the structure container that provides an
        overview of the structures in the container.

        Parameters
        ----------
        print_threshold
            If the number of structures exceeds this number print dots.

        Returns
        -------
        String representation of the structure container.
        """

        if len(self) == 0:
            return 'Empty StructureContainer'

        # Number of structures to print before cutting and printing dots
        if print_threshold is None or print_threshold >= len(self):
            print_threshold = len(self) + 2

        # format specifiers for fields in table
        def get_format(val):
            if isinstance(val, float):
                return '{:9.4f}'
            else:
                return '{}'

        # table headers
        default_headers = ['index', 'user_tag', 'n_atoms', 'chemical formula']
        property_headers = sorted(set(key for fs in self for key in fs.properties))
        headers = default_headers + property_headers

        # collect the table data
        str_table = []
        for i, fs in enumerate(self):
            default_data = [i, fs.user_tag, len(fs), fs.structure.get_chemical_formula()]
            property_data = [fs.properties.get(key, '') for key in property_headers]
            str_row = [get_format(d).format(d) for d in default_data+property_data]
            str_table.append(str_row)
        str_table = np.array(str_table)

        # find maximum widths for each column
        widths = []
        for i in range(str_table.shape[1]):
            data_width = max(len(val) for val in str_table[:, i])
            header_width = len(headers[i])
            widths.append(max([data_width, header_width]))

        total_width = sum(widths) + 3 * len(headers)
        row_format = ' | '.join('{:'+str(width)+'}' for width in widths)

        # Make string representation of table
        s = []
        s += ['{s:=^{n}}'.format(s=' Structure Container ', n=total_width)]
        s += ['Total number of structures: {}'.format(len(self))]
        s += [''.center(total_width, '-')]
        s += [row_format.format(*headers)]
        s += [''.center(total_width, '-')]
        for i, fs_data in enumerate(str_table, start=1):
            s += [row_format.format(*fs_data)]
            # only elide when rows remain; otherwise the last row (just
            # printed) would be printed a second time after the dots
            if i+1 >= print_threshold and i < len(str_table):
                s += [' ...']
                s += [row_format.format(*str_table[-1])]
                break
        s += [''.center(total_width, '=')]
        s = '\n'.join(s)

        return s

    def __str__(self) -> str:
        """ String representation. """
        return self._get_string_representation(print_threshold=50)

    def _repr_html_(self) -> str:
        """ HTML representation. Used, e.g., in jupyter notebooks. """
        parts = ['<h4>Structure Container</h4>',
                 f'<p>Total number of structures: {len(self)}</p>']
        # DataFrame._repr_html_ returns one HTML string (or None when pandas'
        # HTML repr is disabled); append it as one element instead of
        # extending the list character by character
        table = self.to_dataframe()._repr_html_()
        if table is not None:
            parts.append(table)
        return ''.join(parts)

    def to_dataframe(self) -> DataFrame:
        """Summary of :class:`StructureContainer` object in :class:`DataFrame
        <pandas.DataFrame>` format.

        Each row holds the user tag, number of atoms and chemical formula
        (``'metal'`` mode) of one structure, followed by its properties.
        """
        data = []
        for s in self:
            record = dict(
                user_tag=s.user_tag,
                natoms=len(s),
                formula=s.structure.get_chemical_formula('metal'),
            )
            record.update(s.properties)
            data.append(record)
        return DataFrame.from_dict(data)

    def add_structure(self, structure: Atoms, user_tag: str = None,
                      properties: dict = None, allow_duplicate: bool = True,
                      sanity_check: bool = True):
        """
        Adds a structure to the structure container.

        Parameters
        ----------
        structure
            Atomic structure to be added.
        user_tag
            User tag for labeling structure.
        properties
            Scalar properties. If properties are not specified the structure
            object will be checked for an attached ASE calculator object
            with a calculated potential energy. A shallow copy of the
            dictionary is stored.
        allow_duplicate
            Whether or not to add the structure if there already exists a
            structure with identical cluster vector.
        sanity_check
            Whether or not to carry out a sanity check before adding the
            structure. This includes checking occupations and volume.

        Raises
        ------
        TypeError
            If ``structure`` is not an ASE Atoms object or ``user_tag`` is
            not a string.
        ValueError
            If ``allow_duplicate`` is ``False`` and a structure with an
            identical cluster vector is already present.
        """

        # structure must have a proper format and label
        if not isinstance(structure, Atoms):
            raise TypeError(f'structure must be an ASE Atoms object not {type(structure)}')

        if user_tag is not None:
            if not isinstance(user_tag, str):
                raise TypeError(f'user_tag must be a string not {type(user_tag)}.')

        if sanity_check:
            self._cluster_space.assert_structure_compatibility(structure)

        # check for properties in attached calculator
        if properties is None:
            properties = {}
            if structure.calc is not None:
                if not structure.calc.calculation_required(structure, ['energy']):
                    energy = structure.get_potential_energy()
                    properties['energy'] = energy / len(structure)
        else:
            # keep a private copy so that later changes to the caller's dict
            # (e.g. one dict reused for several structures) do not leak in
            properties = dict(properties)

        # check if there exist structures with identical cluster vectors
        structure_copy = structure.copy()
        cv = self._cluster_space.get_cluster_vector(structure_copy)
        if not allow_duplicate:
            for i, fs in enumerate(self):
                if np.allclose(cv, fs.cluster_vector):
                    msg = '{} and {} have identical cluster vectors'.format(
                        user_tag if user_tag is not None else 'Input structure',
                        fs.user_tag if fs.user_tag != 'None' else 'structure')
                    msg += ' at index {}'.format(i)
                    raise ValueError(msg)

        # add structure
        structure = FitStructure(structure_copy, user_tag, cv, properties)
        self._structure_list.append(structure)

    def get_condition_number(self, structure_indices: List[int] = None) -> float:
        """Returns the condition number for the sensing matrix.

        A very large condition number can be a sign of
        multicollinearity. More information can be found
        [here](https://en.wikipedia.org/wiki/Condition_number).

        Parameters
        ----------
        structure_indices
            List of structure indices to include. By default (``None``) the
            method will return all fit data available.

        Returns
        -------
        Condition number of the sensing matrix.
        """
        return np.linalg.cond(self.get_fit_data(structure_indices, key=None)[0])

    def get_fit_data(self, structure_indices: List[int] = None,
                     key: str = 'energy') -> Tuple[np.ndarray, np.ndarray]:
        """
        Returns fit data for all structures. The cluster vectors and
        target properties for all structures are stacked into numpy arrays.

        Parameters
        ----------
        structure_indices
            List of structure indices. By default (``None``) the
            method will return all fit data available.
        key
            Name of property to use. If ``None`` do not include property values.
            This can be useful if only the fit matrix is needed.

        Returns
        -------
        Cluster vectors (stacked row-wise) and target properties for the
        desired structures. The second element is ``None`` if ``key`` is
        ``None``.

        Raises
        ------
        IndexError
            If an index in ``structure_indices`` is out of range.
        KeyError
            If a requested structure has no property named ``key``.
        """
        if structure_indices is None:
            structure_indices = range(len(self._structure_list))
        selected = [self._structure_list[i] for i in structure_indices]

        cv_list = np.array([fs.cluster_vector for fs in selected])
        if key is None:
            # only the fit matrix was requested
            return cv_list, None
        prop_list = np.array([fs.properties[key] for fs in selected])
        return cv_list, prop_list

    @property
    def cluster_space(self) -> ClusterSpace:
        """ Cluster space used to calculate the cluster vectors. """
        return self._cluster_space

    @property
    def available_properties(self) -> List[str]:
        """ Sorted list of the property names present in any structure. """
        return sorted({p for fs in self for p in fs.properties})

    def write(self, outfile: Union[str, BinaryIO, TextIO]) -> None:
        """
        Writes structure container to a file.

        The output is a tar archive with two members: ``cluster_space``
        (written by the cluster space's ``write`` method) and ``database``
        (an ASE database holding the structures with their user tags,
        properties and cluster vectors; an empty file if the container is
        empty).

        Parameters
        ----------
        outfile
            Output file name (str or path-like) or binary file object
            opened for writing.
        """
        import os

        # all intermediate files live in a directory that is removed on exit
        with tempfile.TemporaryDirectory() as temp_dir:
            cs_filename = os.path.join(temp_dir, 'cluster_space')
            db_filename = os.path.join(temp_dir, 'database')

            # Write cluster space to a temporary file
            self.cluster_space.write(cs_filename)

            # Write fit structures as an ASE db in a temporary file
            if self._structure_list:
                db = ase.db.connect(db_filename, type='db', append=False)
                for fit_structure in self._structure_list:
                    data_dict = {'user_tag': fit_structure.user_tag,
                                 'properties': fit_structure.properties,
                                 'cluster_vector': fit_structure.cluster_vector}
                    db.write(fit_structure.structure, data=data_dict)
            else:
                # keep the archive layout: an empty 'database' member
                with open(db_filename, 'wb'):
                    pass

            # paths are opened by tarfile; file objects must go via fileobj
            if isinstance(outfile, (str, bytes, os.PathLike)):
                tar_handle = tarfile.open(outfile, mode='w')
            else:
                tar_handle = tarfile.open(fileobj=outfile, mode='w')
            with tar_handle as handle:
                handle.add(db_filename, arcname='database')
                handle.add(cs_filename, arcname='cluster_space')

    @staticmethod
    def read(infile: Union[str, BinaryIO, TextIO]):
        """
        Reads :class:`StructureContainer` object from file.

        Parameters
        ----------
        infile
            File from which to read: a path (str or path-like) or a file
            object whose ``name`` attribute is the path of the file.

        Returns
        -------
        The structure container stored in the file.

        Raises
        ------
        TypeError
            If the file is not a tar archive.
        """
        import os

        # paths are used as given; file objects are located via their name
        if isinstance(infile, (str, bytes, os.PathLike)):
            filename = infile
        else:
            filename = infile.name

        if not tarfile.is_tarfile(filename):
            raise TypeError('{} is not a tar file'.format(filename))

        # the ASE database is opened by file name, so it is extracted into a
        # temporary directory that is removed once all rows have been read
        with tempfile.TemporaryDirectory() as temp_dir:
            db_filename = os.path.join(temp_dir, 'database')
            with tarfile.open(mode='r', name=filename) as tar_file:
                cs_file = tar_file.extractfile('cluster_space')
                with open(db_filename, 'wb') as db_handle:
                    db_handle.write(tar_file.extractfile('database').read())
                cluster_space = ClusterSpace.read(cs_file)

            database = ase.db.connect(db_filename, type='db')
            fit_structures = []
            for row in database.select():
                data = row.data
                fit_structure = FitStructure(row.toatoms(),
                                             user_tag=data['user_tag'],
                                             cluster_vector=data['cluster_vector'],
                                             properties=data['properties'])
                fit_structures.append(fit_structure)

        structure_container = StructureContainer(cluster_space)
        structure_container._structure_list = fit_structures
        return structure_container
class FitStructure:
    """
    This class holds a supercell along with its properties and cluster
    vector.

    Entries of :attr:`properties` can also be read as attributes, e.g.,
    ``fit_structure.energy`` returns ``fit_structure.properties['energy']``.

    Attributes
    ----------
    structure
        Supercell structure.
    user_tag
        Custom user tag.
    cluster_vector
        Cluster vector.
    properties
        Dictionary comprising name and value of properties.
    """

    def __init__(self, structure: Atoms, user_tag: str,
                 cluster_vector: np.ndarray, properties: dict = None):
        self._structure = structure
        self._user_tag = user_tag
        self._cluster_vector = cluster_vector
        # a fresh dict per instance; a ``{}`` default would be shared by all
        # instances created without explicit properties
        self.properties = {} if properties is None else properties

    @property
    def cluster_vector(self) -> np.ndarray:
        """ Cluster vector. """
        return self._cluster_vector

    @property
    def structure(self) -> Atoms:
        """ Atomic structure. """
        return self._structure

    @property
    def user_tag(self) -> str:
        """ Structure label (``'None'`` if no tag was given). """
        return str(self._user_tag)

    def __getattr__(self, key):
        """ Accesses properties if possible and returns value.

        Only called when normal attribute lookup fails. ``properties`` is
        taken from the instance ``__dict__`` directly: on an instance that
        has not been initialized yet (e.g. during ``copy.copy`` or
        unpickling) ``self.properties`` would itself end up here and recurse
        without end.
        """
        properties = self.__dict__.get('properties')
        if properties is not None and key in properties:
            return properties[key]
        # raises the standard AttributeError
        return super().__getattribute__(key)

    def __len__(self) -> int:
        """ Number of sites in the structure. """
        return len(self._structure)

    def __str__(self) -> str:
        width = 50
        s = []
        s += ['{s:=^{n}}'.format(s=' Fit Structure ', n=width)]
        s += [' {:22} : {}'.format('user tag', self.user_tag)]
        for k, v in self.properties.items():
            s += [f' {k:22} : {v}']
        # the label is printed only on the first line of each section
        t = 'cell metric'
        for row in self.structure.cell[:]:
            s += [f' {t:22} : {row}']
            t = ''
        t = 'sites'
        for site in self.structure:
            s += [f' {t:22} : {site.index} {site.symbol:2} {site.position}']
            t = ''
        s += [''.center(width, '=')]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ HTML representation. Used, e.g., in jupyter notebooks. """
        s = ['<h4>FitStructure</h4>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Property</th><th>Value</th></tr></thead>']
        s += ['<tbody>']
        s += [f'<tr><td style="text-align: left;">user tag</td><td>{self.user_tag}</td></tr>']
        for key, value in sorted(self.properties.items()):
            s += [f'<tr><td style="text-align: left;">{key}</td><td>{value}</td></tr>']
        s += ['</tbody></table>']

        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Cell</th></tr></thead>']
        s += ['<tbody>']
        for row in self.structure.cell[:]:
            s += ['<tr>']
            for c in row:
                s += [f'<td>{c}</td>']
            s += ['</tr>']
        s += ['</tbody></table>']

        df = DataFrame(np.array([self.structure.symbols,
                                 self.structure.positions[:, 0],
                                 self.structure.positions[:, 1],
                                 self.structure.positions[:, 2]]).T,
                       columns=['Species', 'Position x', 'Position y', 'Position z'])
        # append the table as one element (not character by character); it
        # can be None when pandas' HTML repr is disabled
        sites_table = df._repr_html_()
        if sites_table is not None:
            s.append(sites_table)
        return ''.join(s)