Coverage for icet/core/cluster_expansion.py: 97%
195 statements
« prev ^ index » next coverage.py v7.5.0, created at 2025-03-09 04:14 +0000
1"""
2This module provides the ClusterExpansion class.
3"""
5import os
6import pandas as pd
7import numpy as np
8import pickle
9import tempfile
10import tarfile
11import re
13from icet import ClusterSpace
14from icet.core.structure import Structure
15from typing import List, Union
16from ase import Atoms
class ClusterExpansion:
    """Cluster expansions are obtained by combining a cluster space with a set
    of parameters, where the latter is commonly obtained by optimization.
    Instances of this class allow one to predict the property of interest for
    a given structure.

    Note
    ----
    Each element of the parameter vector corresponds to an effective cluster
    interaction (ECI) multiplied by the multiplicity of the underlying orbit.

    Attributes
    ----------
    cluster_space
        Cluster space that was used for constructing the cluster expansion.
    parameters
        Parameter vector.
    metadata
        Metadata dictionary, user-defined metadata to be stored together
        with cluster expansion. Will be pickled when CE is written to file.
        By default contains icet version, username, hostname and date.
        The dictionary passed by the caller is copied, not modified.

    Raises
    ------
    ValueError
        If :attr:`cluster_space` and :attr:`parameters` differ in length.

    Example
    -------
    The following snippet illustrates the initialization and usage of a
    :class:`ClusterExpansion` object. Here, the parameters are taken to be
    a list of ones. Usually, they would be obtained by training with
    respect to a set of reference data::

        >>> from ase.build import bulk
        >>> from icet import ClusterSpace, ClusterExpansion

        >>> # create cluster expansion with fake parameters
        >>> prim = bulk('Au')
        >>> cs = ClusterSpace(prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Pd']])
        >>> parameters = len(cs) * [1.0]
        >>> ce = ClusterExpansion(cs, parameters)

        >>> # make prediction for supercell
        >>> sc = prim.repeat(3)
        >>> for k in [1, 4, 7]:
        ...     sc[k].symbol = 'Pd'
        >>> print(ce.predict(sc))
    """

    def __init__(self, cluster_space: ClusterSpace, parameters: np.array,
                 metadata: dict = None) -> None:
        if len(cluster_space) != len(parameters):
            raise ValueError('cluster_space ({}) and parameters ({}) must have'
                             ' the same length'.format(len(cluster_space), len(parameters)))
        self._cluster_space = cluster_space.copy()
        # np.asarray converts lists *and* tuples (ndarrays pass through
        # unchanged); prune() relies on ndarray fancy indexing.
        self._parameters = np.asarray(parameters)

        # add metadata; copy the caller's dict so that the default entries
        # added below do not leak back into the object they passed in
        self._metadata = dict(metadata) if metadata is not None else dict()
        self._add_default_metadata()

    def predict(self, structure: Union[Atoms, Structure]) -> float:
        """
        Returns the property value predicted by the cluster expansion, i.e.
        the dot product of the cluster vector of ``structure`` with the
        parameter vector.

        Parameters
        ----------
        structure
            Atomic configuration.
        """
        cluster_vector = self._cluster_space.get_cluster_vector(structure)
        prop = np.dot(cluster_vector, self.parameters)
        return prop

    def get_cluster_space_copy(self) -> ClusterSpace:
        """ Returns copy of cluster space on which cluster expansion is based. """
        return self._cluster_space.copy()

    def to_dataframe(self) -> pd.DataFrame:
        """Returns a representation of the cluster expansion in the form of a
        DataFrame including effective cluster interactions (ECIs).

        Each row corresponds to one entry of the cluster space; the columns
        ``parameter`` and ``eci`` (parameter divided by multiplicity) are
        added, and the ``index`` column of the cluster space listing is
        dropped.
        """
        # Work on copies of the row dicts so the cluster space's own listing
        # is never modified, even if ``as_list`` were to return cached data.
        rows = [dict(row) for row in self._cluster_space.as_list]
        for row, param in zip(rows, self.parameters):
            row['parameter'] = param
            row['eci'] = param / row['multiplicity']
        df = pd.DataFrame(rows)
        del df['index']
        return df

    @property
    def chemical_symbols(self) -> List[List[str]]:
        """ Species identified by their chemical symbols (copy). """
        return self._cluster_space.chemical_symbols.copy()

    @property
    def cutoffs(self) -> List[float]:
        """
        Cutoffs for different n-body clusters (copy). The cutoff radius (in
        Ångstroms) defines the largest interatomic distance in a
        cluster.
        """
        return self._cluster_space.cutoffs.copy()

    @property
    def orders(self) -> List[int]:
        """ Orders included in cluster expansion: 0 and 1 plus one order per
        cutoff. """
        return list(range(len(self._cluster_space.cutoffs) + 2))

    @property
    def parameters(self) -> np.ndarray:
        """ Parameter vector. Each element of the parameter vector corresponds
        to an effective cluster interaction (ECI) multiplied by the
        multiplicity of the respective orbit. """
        return self._parameters

    @property
    def metadata(self) -> dict:
        """ Metadata associated with the cluster expansion. """
        return self._metadata

    @property
    def symprec(self) -> float:
        """ Tolerance imposed when analyzing the symmetry using spglib
        (inherited from the underlying cluster space). """
        return self._cluster_space.symprec

    @property
    def position_tolerance(self) -> float:
        """ Tolerance applied when comparing positions in Cartesian coordinates
        (inherited from the underlying cluster space). """
        return self._cluster_space.position_tolerance

    @property
    def fractional_position_tolerance(self) -> float:
        """ Tolerance applied when comparing positions in fractional coordinates
        (inherited from the underlying cluster space). """
        return self._cluster_space.fractional_position_tolerance

    @property
    def primitive_structure(self) -> Atoms:
        """ Primitive structure on which cluster expansion is based (copy). """
        return self._cluster_space.primitive_structure.copy()

    def __len__(self) -> int:
        """ Number of parameters (equal to the length of the cluster space). """
        return len(self._parameters)

    def _get_string_representation(self, print_threshold: int = None,
                                   print_minimum: int = 10):
        """ String representation of the cluster expansion.

        Parameters
        ----------
        print_threshold
            If set and the number of parameters exceeds it, the table is
            truncated to its first and last ``print_minimum`` rows.
        print_minimum
            Number of rows shown at the head and at the tail of a truncated
            table.
        """
        cluster_space_repr = self._cluster_space._get_string_representation(
            print_threshold, print_minimum).split('\n')
        # rescale width
        par_col_width = max(len('{:9.3g}'.format(max(self._parameters, key=abs))), len('ECI'))
        width = len(cluster_space_repr[0]) + 2 * (len(' | ') + par_col_width)

        s = []
        s += ['{s:=^{n}}'.format(s=' Cluster Expansion ', n=width)]
        s += [t for t in cluster_space_repr if re.search(':', t)]

        # additional information about number of nonzero parameters
        df = self.to_dataframe()
        orders = self.orders
        nzp_by_order = [np.count_nonzero(df[df.order == order].eci) for order in orders]
        assert sum(nzp_by_order) == np.count_nonzero(self.parameters)
        s += [' {:38} : {}'.format('total number of nonzero parameters', sum(nzp_by_order))]
        line = ' {:38} :'.format('number of nonzero parameters by order')
        for order, nzp in zip(orders, nzp_by_order):
            line += ' {}= {} '.format(order, nzp)
        s += [line]

        # table header
        s += [''.center(width, '-')]
        t = [t for t in cluster_space_repr if 'index' in t]
        t += ['{s:^{n}}'.format(s='parameter', n=par_col_width)]
        t += ['{s:^{n}}'.format(s='ECI', n=par_col_width)]
        s += [' | '.join(t)]
        s += [''.center(width, '-')]

        # table body; the cluster space listing is loop-invariant, so build
        # it once instead of once per printed row
        cs_rows = self._cluster_space.as_list
        index = 0
        while index < len(self):
            if (print_threshold is not None and
                    len(self) > print_threshold and
                    index >= print_minimum and
                    index <= len(self) - print_minimum):
                # skip the middle of the table and jump to the tail
                index = len(self) - print_minimum
                s += [' ...']
            pattern = r'^{:4}'.format(index)
            t = [t for t in cluster_space_repr if re.match(pattern, t)]
            parameter = self._parameters[index]
            t += ['{s:^{n}}'.format(s=f'{parameter:9.3g}', n=par_col_width)]
            eci = parameter / cs_rows[index]['multiplicity']
            t += ['{s:^{n}}'.format(s=f'{eci:9.3g}', n=par_col_width)]
            s += [' | '.join(t)]
            index += 1
        s += [''.center(width, '=')]

        return '\n'.join(s)

    def __str__(self) -> str:
        """ String representation (table truncated beyond 50 parameters). """
        return self._get_string_representation(print_threshold=50)

    def _repr_html_(self) -> str:
        """ HTML representation. Used, e.g., in jupyter notebooks. """
        s = ['<h4>Cluster Expansion</h4>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>']
        s += ['<tbody>']
        s += ['<tr><td style="text-align: left;">Space group</td>'
              f'<td>{self._cluster_space.space_group}</td></tr>']
        for sl in self._cluster_space.get_sublattices(
                self.primitive_structure).active_sublattices:
            s += [f'<tr><td style="text-align: left;">Sublattice {sl.symbol}</td>'
                  f'<td>{sl.chemical_symbols}</td></tr>']
        s += ['<tr><td style="text-align: left;">Cutoffs</td>'
              f'<td>{self._cluster_space.cutoffs}</td></tr>']

        df = self.to_dataframe()
        nzp_by_order = [np.count_nonzero(df[df.order == order].eci) for order in self.orders]
        assert sum(nzp_by_order) == np.count_nonzero(self.parameters)
        s += ['<tr><td style="text-align: left;">Total number of parameters (nonzero)</td>'
              f'<td>{len(self)} ({sum(nzp_by_order)})</td></tr>']
        # NOTE(review): assumes number_of_orbits_by_order iterates orders in
        # the same sequence as self.orders — confirm in ClusterSpace.
        for (order, npar), nzp in zip(
                self._cluster_space.number_of_orbits_by_order.items(), nzp_by_order):
            s += ['<tr><td style="text-align: left;">'
                  f'Number of parameters of order {order} (nonzero)</td>'
                  f'<td>{npar} ({nzp})</td></tr>']
        s += ['<tr><td style="text-align: left;">fractional_position_tolerance</td>'
              f'<td>{self._cluster_space.fractional_position_tolerance}</td></tr>']
        s += ['<tr><td style="text-align: left;">position_tolerance</td>'
              f'<td>{self._cluster_space.position_tolerance}</td></tr>']
        s += ['<tr><td style="text-align: left;">symprec</td>'
              f'<td>{self._cluster_space.symprec}</td></tr>']

        s += ['</tbody>']
        s += ['</table>']
        return ''.join(s)

    def __repr__(self) -> str:
        """ Representation. """
        s = type(self).__name__ + '('
        s += f'cluster_space={self._cluster_space.__repr__()}'
        # tolist() yields native Python scalars, so the output does not
        # depend on NumPy's scalar repr (``np.float64(...)`` in NumPy >= 2)
        s += f', parameters={np.asarray(self._parameters).tolist()!r}'
        s += ')'
        return s

    def prune(self, indices: List[int] = None, tol: float = 0) -> None:
        """Removes orbits from the cluster expansion for which the absolute
        values of the corresponding parameters are zero or close to zero.
        This commonly reduces the computational cost of evaluating the
        cluster expansion. It is therefore recommended to apply this method
        prior to using the cluster expansion in production. If the method
        is called without arguments, only orbits for which the ECIs are
        strictly zero are pruned. Less restrictive pruning can be achieved
        by setting the ``tol`` keyword.

        An orbit is only removed if *all* of its parameters are selected for
        removal; the zerolet (index 0) is never pruned.

        Parameters
        ----------
        indices
            Indices of parameters to remove from the cluster expansion.
            If given, ``tol`` is ignored.
        tol
            All orbits will be pruned for which the absolute parameter value(s)
            is/are within this tolerance.

        Raises
        ------
        ValueError
            If ``indices`` contains 0 (the zerolet).
        """
        from collections import Counter

        # find parameter indices to be removed
        if indices is None:
            indices = [i for i, param in enumerate(
                self.parameters) if np.abs(param) <= tol and i > 0]
        indices = sorted(set(indices))

        if 0 in indices:
            raise ValueError('Orbit index cannot be 0 since the zerolet may not be pruned.')
        if not indices:
            # nothing to prune
            return

        df = self.to_dataframe()
        # label-based lookup; the frame has a default RangeIndex, so labels
        # coincide with parameter indices
        orbit_candidates_for_removal = df.orbit_index.loc[indices].tolist()
        orbit_counts = Counter(df.orbit_index.tolist())
        removal_counts = Counter(orbit_candidates_for_removal)

        safe_to_remove_orbits, safe_to_remove_params = [], []
        for oi in sorted(removal_counts):
            if oi == -1:
                # entries without an orbit cannot be pruned
                continue
            # only drop an orbit if every one of its parameters was selected
            if orbit_counts[oi] <= removal_counts[oi]:
                safe_to_remove_orbits.append(oi)
                safe_to_remove_params += df.index[df['orbit_index'] == oi].tolist()

        # prune cluster space
        self._cluster_space.prune_orbit_list(indices=safe_to_remove_orbits)
        self._parameters = self._parameters[np.setdiff1d(
            np.arange(len(self._parameters)), safe_to_remove_params)]
        assert len(self._parameters) == len(self._cluster_space)

    def write(self, filename: str) -> None:
        """
        Writes ClusterExpansion object to file.

        The file is a tar archive with two members: ``cluster_space`` (the
        cluster space as written by its own ``write`` method) and ``items``
        (a pickled dict holding the parameters and metadata).

        Parameters
        ----------
        filename
            name of file to which to write
        """
        items = dict()
        items['parameters'] = self.parameters

        # TODO: remove if condition once metadata is firmly established
        if hasattr(self, '_metadata'):
            items['metadata'] = self._metadata

        # The cluster space is serialized to a named temp file, which is
        # removed even if writing the archive fails.
        cs_file = tempfile.NamedTemporaryFile(delete=False)
        cs_file.close()
        try:
            self._cluster_space.write(cs_file.name)
            with tarfile.open(name=filename, mode='w') as tar_file:
                tar_file.add(cs_file.name, arcname='cluster_space')

                # write items
                with tempfile.TemporaryFile() as temp_file:
                    pickle.dump(items, temp_file)
                    # seek() also flushes the buffer, so the size reported
                    # by gettarinfo (via fstat) is complete
                    temp_file.seek(0)
                    tar_info = tar_file.gettarinfo(arcname='items', fileobj=temp_file)
                    tar_file.addfile(tar_info, temp_file)
        finally:
            os.remove(cs_file.name)

    @staticmethod
    def read(filename: str) -> 'ClusterExpansion':
        """
        Reads :class:`ClusterExpansion` object from file.

        Warning
        -------
        The parameters and metadata are restored with :func:`pickle.load`,
        which can execute arbitrary code. Only read files from trusted
        sources.

        Parameters
        ----------
        filename
            File from which to read.
        """
        with tarfile.open(name=filename, mode='r') as tar_file:
            cs_file = tempfile.NamedTemporaryFile(delete=False)
            try:
                cs_file.write(tar_file.extractfile('cluster_space').read())
                cs_file.close()
                cs = ClusterSpace.read(cs_file.name)
            finally:
                # closing twice is harmless; the file must be closed before
                # removal on some platforms
                cs_file.close()
                os.remove(cs_file.name)
            items = pickle.load(tar_file.extractfile('items'))

        ce = ClusterExpansion.__new__(ClusterExpansion)
        ce._cluster_space = cs
        ce._parameters = items['parameters']

        # TODO: remove fallback once metadata is firmly established.
        # Older files carry no metadata; default to an empty dict so that the
        # ``metadata`` property does not raise AttributeError.
        ce._metadata = items.get('metadata', dict())
        return ce

    def _add_default_metadata(self):
        """ Adds default metadata (creation date, username, hostname and icet
        version) to the metadata dict, overwriting existing entries with the
        same keys. """
        import getpass
        import socket
        from datetime import datetime
        from icet import __version__ as icet_version

        self._metadata['date_created'] = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
        self._metadata['username'] = getpass.getuser()
        self._metadata['hostname'] = socket.gethostname()
        self._metadata['icet_version'] = icet_version