Coverage for mchammer/observers/cluster_count_observer.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-06 04:14 +0000

1from typing import Dict, List, Tuple 

2from collections.abc import Iterable 

3 

4import pandas as pd 

5 

6from ase import Atoms 

7from icet import ClusterSpace 

8from icet.core.local_orbit_list_generator import LocalOrbitListGenerator 

9from icet.core.structure import Structure 

10from icet.tools.geometry import chemical_symbols_to_numbers 

11from mchammer.observers.base_observer import BaseObserver 

12 

13 

class ClusterCountObserver(BaseObserver):
    """This class represents a cluster count observer.

    A cluster count observer enables one to keep track of the
    occupation of clusters along the trajectory sampled by a Monte
    Carlo (MC) simulation. For example, given several canonical MC
    simulations representing different temperatures, this observer can
    be used to study the temperature dependence of the number of
    nearest neighbors of a particular species.

    Parameters
    ----------
    cluster_space
        Cluster space to define the clusters to be counted.
    structure
        Defines the lattice that the observer will work on.
    interval
        Observation interval. Defaults to ``None`` meaning that if the
        observer is used in a Monte Carlo simulation, then the
        :class:`Ensemble` object will determine the interval.
    orbit_indices
        Indices of the orbits whose clusters should be counted.
        Any iterable of integers is accepted; it is copied into a list.
        By default all orbits are included.

    Raises
    ------
    ValueError
        If ``orbit_indices`` is neither ``None`` nor an iterable.
    """

    def __init__(self, cluster_space: ClusterSpace,
                 structure: Atoms,
                 interval: int = None,
                 orbit_indices: List[int] = None) -> None:
        super().__init__(interval=interval, return_type=dict, tag='ClusterCountObserver')

        self._cluster_space = cluster_space
        local_orbit_list_generator = LocalOrbitListGenerator(
            orbit_list=cluster_space.orbit_list,
            structure=Structure.from_atoms(structure),
            fractional_position_tolerance=cluster_space.fractional_position_tolerance)

        self._full_orbit_list = local_orbit_list_generator.generate_full_orbit_list()

        if orbit_indices is None:
            self._orbit_indices = list(range(len(self._full_orbit_list)))
        elif not isinstance(orbit_indices, Iterable):
            raise ValueError('Argument orbit_indices should be a list of integers, '
                             f'not {type(orbit_indices)}')
        else:
            # Materialize the indices: they are iterated once below (in
            # _get_possible_occupations) and again on every observation, so a
            # one-shot iterator such as a generator would otherwise be empty
            # after the first pass. Copying also decouples the observer from
            # later mutation of the caller's list.
            self._orbit_indices = list(orbit_indices)

        self._possible_occupations = self._get_possible_occupations()

    def _get_possible_occupations(self) -> Dict[int, List[Tuple[str]]]:
        """Returns a dictionary mapping each selected orbit index to the
        possible occupations of that orbit, as reported by the cluster space.

        Raises
        ------
        AssertionError
            If the number of sites in the possible occupations does not
            match the order of the corresponding orbit in the full orbit list.
        """
        possible_occupations = {}
        for i in self._orbit_indices:
            possible_occupations_orbit = self._cluster_space.get_possible_orbit_occupations(i)
            order = self._full_orbit_list.get_orbit(i).order
            # Internal consistency check; the message must describe this
            # orbit's occupations, not the dict being assembled.
            assert order == len(possible_occupations_orbit[0]), \
                f'Order (n={order}) does not match possible occupations' \
                f' (n={len(possible_occupations_orbit[0])}, {possible_occupations_orbit}).'
            possible_occupations[i] = possible_occupations_orbit
        return possible_occupations

    def get_cluster_counts(self, structure: Atoms) -> pd.DataFrame:
        """Counts the number of times different clusters appear in the structure
        and returns this information as a pandas dataframe.

        Parameters
        ----------
        structure
            input atomic structure.

        Returns
        -------
        pd.DataFrame
            One row per (orbit index, possible occupation) pair with the
            columns ``dc_tag``, ``occupation``, ``cluster_count``,
            ``orbit_index`` and ``order``. Occupations that do not occur
            in the structure are reported with a count of 0. If no orbits
            are selected the dataframe is empty.
        """
        rows = []
        structure_icet = Structure.from_atoms(structure)
        for i in self._orbit_indices:
            orbit = self._full_orbit_list.get_orbit(i)
            cluster_counts = orbit.get_cluster_counts(structure_icet)
            for chemical_symbols in self._possible_occupations[i]:
                # Counts are keyed by tuples of atomic numbers.
                count = cluster_counts.get(tuple(chemical_symbols_to_numbers(chemical_symbols)), 0)
                rows.append({
                    'dc_tag': '{}_{}'.format(i, '_'.join(chemical_symbols)),
                    'occupation': chemical_symbols,
                    'cluster_count': count,
                    'orbit_index': i,
                    'order': orbit.order,
                })
        return pd.DataFrame(rows)

    def get_observable(self, structure: Atoms) -> dict:
        """
        Returns the cluster counts for a given atomic configuration.

        Parameters
        ----------
        structure
            input atomic structure

        Returns
        -------
        dict
            Mapping from data-container tag (``'<orbit index>_<symbols>'``)
            to the number of clusters with that occupation.
        """
        counts = self.get_cluster_counts(structure)
        # iterrows() yields nothing for an empty dataframe, giving {}.
        count_dict = {row['dc_tag']: row['cluster_count']
                      for _, row in counts.iterrows()}
        return count_dict