Coverage for mchammer/observers/cluster_count_observer.py: 100%
48 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-14 04:08 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-14 04:08 +0000
1from collections.abc import Iterable
3import pandas as pd
5from ase import Atoms
6from icet import ClusterSpace
7from icet.core.local_orbit_list_generator import LocalOrbitListGenerator
8from icet.core.structure import Structure
9from icet.tools.geometry import chemical_symbols_to_numbers
10from mchammer.observers.base_observer import BaseObserver
class ClusterCountObserver(BaseObserver):
    """This class represents a cluster count observer.

    A cluster count observer enables one to keep track of the
    occupation of clusters along the trajectory sampled by a Monte
    Carlo (MC) simulation. For example, given several canonical MC
    simulations representing different temperatures, this observer can
    be used to access the temperature dependence of the number of
    nearest neighbors of a particular species.

    The outputted cluster counts come in a dictionary format with keys like
    ``0_Al``, ``0_Cu``, ``1_Al_Al``, ``1_Al_Cu``, ``1_Cu_Al``, ``1_Cu_Cu``, etc,
    where the number indicates the orbit index and the symbols indicate the
    elements for this orbit. The values of the dictionary contain the number
    of such clusters found in the structure.

    Parameters
    ----------
    cluster_space
        Cluster space to define the clusters to be counted.
    structure
        Defines the lattice that the observer will work on.
    orbit_indices
        Indices of the orbits to include.
        By default all orbits are included.
    interval
        Observation interval. Defaults to ``None`` meaning that if the
        observer is used in a Monte Carlo simulations, then the :class:`Ensemble` object
        will determine the interval.

    Raises
    ------
    ValueError
        If ``orbit_indices`` is neither ``None`` nor an iterable.
    """

    def __init__(self, cluster_space: ClusterSpace,
                 structure: Atoms,
                 interval: int = None,
                 orbit_indices: list[int] = None) -> None:
        super().__init__(interval=interval, return_type=dict, tag='ClusterCountObserver')

        self._cluster_space = cluster_space
        local_orbit_list_generator = LocalOrbitListGenerator(
            orbit_list=cluster_space.orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=cluster_space.fractional_position_tolerance)

        self._full_orbit_list = local_orbit_list_generator.generate_full_orbit_list()

        if orbit_indices is None:
            self._orbit_indices = list(range(len(self._full_orbit_list)))
        elif not isinstance(orbit_indices, Iterable):
            raise ValueError('Argument orbit_indices should be a list of integers, '
                             f'not {type(orbit_indices)}')
        else:
            # Keep a private copy: the indices are iterated on every
            # observation and must stay in sync with the occupations
            # cached below, so later mutation of the caller's list (or a
            # one-shot iterator) must not affect the observer.
            self._orbit_indices = list(orbit_indices)

        self._possible_occupations = self._get_possible_occupations()

    def _get_possible_occupations(self) -> dict[int, list[tuple[str, ...]]]:
        """ Returns a dictionary containing the possible occupations for each orbit. """
        possible_occupations = {}
        for i in self._orbit_indices:
            possible_occupations_orbit = self._cluster_space.get_possible_orbit_occupations(i)
            order = self._full_orbit_list.get_orbit(i).order
            # Internal consistency check: every occupation of an orbit must
            # list exactly one species per site of the cluster.
            n_sites = len(possible_occupations_orbit[0])
            assert order == n_sites, \
                f'Order (n={order}) does not match possible occupations' \
                f' (n={n_sites}, {possible_occupations_orbit}).'
            possible_occupations[i] = possible_occupations_orbit
        return possible_occupations

    def get_cluster_counts(self, structure: Atoms) -> pd.DataFrame:
        """Counts the number of times different clusters appear in the structure
        and returns this information as a pandas dataframe.

        The dataframe has one row per orbit and possible occupation with the
        columns ``dc_tag``, ``occupation``, ``cluster_count``, ``orbit_index``
        and ``order``. Occupations that do not occur in the structure are
        reported with a count of zero.

        Parameters
        ----------
        structure
            input atomic structure.
        """
        rows = []
        structure_icet = Structure.from_atoms(structure)
        for i in self._orbit_indices:
            orbit = self._full_orbit_list.get_orbit(i)
            cluster_counts = orbit.get_cluster_counts(structure_icet)
            for chemical_symbols in self._possible_occupations[i]:
                # The counts are looked up by atomic numbers rather than symbols.
                key = tuple(chemical_symbols_to_numbers(chemical_symbols))
                rows.append({
                    'dc_tag': '{}_{}'.format(i, '_'.join(chemical_symbols)),
                    'occupation': chemical_symbols,
                    'cluster_count': cluster_counts.get(key, 0),
                    'orbit_index': i,
                    'order': orbit.order,
                })
        return pd.DataFrame(rows)

    def get_observable(self, structure: Atoms) -> dict:
        """
        Returns the cluster counts for a given atomic configuration as a
        dictionary that maps the tag of each cluster (orbit index followed
        by the chemical symbols, e.g., ``1_Al_Cu``) to its count.

        Parameters
        ----------
        structure
            input atomic structure
        """
        counts = self.get_cluster_counts(structure)
        if counts.empty:
            # Without any rows the dataframe has no columns to select from.
            return {}
        # Pair the two relevant columns directly instead of materialising a
        # Series for every row via iterrows().
        return dict(zip(counts['dc_tag'], counts['cluster_count']))