Coverage for mchammer/observers/cluster_count_observer.py: 100%
49 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
« prev ^ index » next coverage.py v7.5.0, created at 2024-12-26 04:12 +0000
1from typing import Dict, List, Tuple
2from collections.abc import Iterable
4import pandas as pd
6from ase import Atoms
7from icet import ClusterSpace
8from icet.core.local_orbit_list_generator import LocalOrbitListGenerator
9from icet.core.structure import Structure
10from icet.tools.geometry import chemical_symbols_to_numbers
11from mchammer.observers.base_observer import BaseObserver
class ClusterCountObserver(BaseObserver):
    """This class represents a cluster count observer.

    A cluster count observer enables one to keep track of the
    occupation of clusters along the trajectory sampled by a Monte
    Carlo (MC) simulation. For example, given several canonical MC
    simulations representing different temperatures, this observer can
    be used to access the temperature dependence of the number of
    nearest neighbors of a particular species.

    The cluster counts are returned as a dictionary with keys like
    ``0_Al``, ``0_Cu``, ``1_Al_Al``, ``1_Al_Cu``, ``1_Cu_Al``, ``1_Cu_Cu``, etc,
    where the number indicates the orbit index and the symbols indicate the
    elements for this orbit. The values of the dictionary are the number of
    such clusters found in the structure.

    Parameters
    ----------
    cluster_space
        Cluster space to define the clusters to be counted.
    structure
        Defines the lattice that the observer will work on.
    interval
        Observation interval. Defaults to ``None`` meaning that if the
        observer is used in a Monte Carlo simulations, then the :class:`Ensemble` object
        will determine the interval.
    orbit_indices
        Indices of the orbits to include in the count.
        By default all orbits are included.

    Raises
    ------
    ValueError
        If :attr:`orbit_indices` is given but is not iterable.
    """

    def __init__(self, cluster_space: ClusterSpace,
                 structure: Atoms,
                 interval: int = None,
                 orbit_indices: List[int] = None) -> None:
        super().__init__(interval=interval, return_type=dict, tag='ClusterCountObserver')

        self._cluster_space = cluster_space
        local_orbit_list_generator = LocalOrbitListGenerator(
            orbit_list=cluster_space.orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=cluster_space.fractional_position_tolerance)

        self._full_orbit_list = local_orbit_list_generator.generate_full_orbit_list()

        if orbit_indices is None:
            # Default: every orbit in the full orbit list.
            self._orbit_indices = list(range(len(self._full_orbit_list)))
        elif not isinstance(orbit_indices, Iterable):
            raise ValueError('Argument orbit_indices should be a list of integers, '
                             f'not {type(orbit_indices)}')
        else:
            # Materialize the indices: they are iterated here (via
            # _get_possible_occupations) and again on every call to
            # get_cluster_counts, so a one-shot iterable such as a generator
            # would otherwise be exhausted after construction.
            self._orbit_indices = list(orbit_indices)

        self._possible_occupations = self._get_possible_occupations()

    def _get_possible_occupations(self) -> Dict[int, List[Tuple[str]]]:
        """ Returns a dictionary containing the possible occupations for each orbit. """
        possible_occupations = {}
        for i in self._orbit_indices:
            possible_occupations_orbit = self._cluster_space.get_possible_orbit_occupations(i)
            order = self._full_orbit_list.get_orbit(i).order
            # Sanity check: the length of an occupation must equal the orbit order.
            # The message reports the occupations of *this* orbit, not the
            # partially filled result dictionary.
            assert order == len(possible_occupations_orbit[0]), \
                f'Order (n={order}) does not match possible occupations' \
                f' (n={len(possible_occupations_orbit[0])}, {possible_occupations_orbit}).'
            possible_occupations[i] = possible_occupations_orbit
        return possible_occupations

    def get_cluster_counts(self, structure: Atoms) -> pd.DataFrame:
        """Counts the number of times different clusters appear in the structure
        and returns this information as a pandas dataframe.

        The dataframe has one row per orbit and possible occupation, with the
        columns ``dc_tag``, ``occupation``, ``cluster_count``, ``orbit_index``,
        and ``order``.

        Parameters
        ----------
        structure
            input atomic structure.
        """
        rows = []
        structure_icet = Structure.from_atoms(structure)
        for i in self._orbit_indices:
            orbit = self._full_orbit_list.get_orbit(i)
            cluster_counts = orbit.get_cluster_counts(structure_icet)
            for chemical_symbols in self._possible_occupations[i]:
                # The counts are looked up by the tuple returned from
                # chemical_symbols_to_numbers; occupations that do not occur
                # in the structure are reported with a count of zero.
                count = cluster_counts.get(tuple(chemical_symbols_to_numbers(chemical_symbols)), 0)
                rows.append({
                    'dc_tag': f'{i}_{"_".join(chemical_symbols)}',
                    'occupation': chemical_symbols,
                    'cluster_count': count,
                    'orbit_index': i,
                    'order': orbit.order,
                })
        return pd.DataFrame(rows)

    def get_observable(self, structure: Atoms) -> dict:
        """
        Returns the cluster counts for a given atomic configuration as a
        dictionary that maps tags such as ``1_Al_Cu`` (orbit index followed
        by the chemical symbols) to the number of such clusters.

        Parameters
        ----------
        structure
            input atomic structure
        """
        counts = self.get_cluster_counts(structure)
        count_dict = {row['dc_tag']: row['cluster_count']
                      for _, row in counts.iterrows()}
        return count_dict