"""This module contains a class to obtain robocrystallographer ML features."""
from __future__ import annotations
from itertools import product
from matminer.featurizers.base import BaseFeaturizer
from numpy import mean
from pymatgen.analysis.local_env import cn_opt_params
from pymatgen.core.structure import Structure
from robocrys import StructureCondenser
from robocrys.featurize.adapter import FeaturizerAdapter
from robocrys.util import connected_geometries
# Flat list of every entry obtained by iterating each value of pymatgen's
# ``cn_opt_params`` mapping; these entries are used below as geometry names.
_geometries = [
    geometry_name
    for cn_geometries in cn_opt_params.values()
    for geometry_name in cn_geometries
]

# Single dimensionalities that are probed individually, highest first.
_dimensionalities = (3, 2, 1, 0)

# Combinations of two or more dimensionalities that are probed together.
# The order here fixes the order of the corresponding feature columns.
_dimensionality_sets = (
    # combinations that include 3
    (3, 2, 1, 0),
    (3, 2, 1),
    (3, 2, 0),
    (3, 1, 0),
    (3, 2),
    (3, 1),
    (3, 0),
    # combinations whose highest member is 2
    (2, 1, 0),
    (2, 1),
    (2, 0),
    # combination whose highest member is 1
    (1, 0),
)

# Molecule names that each get their own "contains" feature.
_molecules = [
    "water",
    "oxygen",
    "ammonia",
    "methane",
]

# Polyhedral connectivity types that are paired with connected geometries.
_connectivities = [
    "corner",
    "edge",
    "face",
]

# Coordination numbers 1 through 12 inclusive.
_cns = range(1, 13)
class RobocrysFeaturizer(BaseFeaturizer):
    """Class to generate structure features from robocrystallographer output.

    Args:
        condenser_kwargs: Keyword arguments that will be passed to
            :obj:`robocrys.condense.StructureCondenser`.
        distorted_tol: The value under which the site geometry will be
            classified as distorted.
    """

    def __init__(
        self, condenser_kwargs: dict | None = None, distorted_tol: float = 0.6
    ) -> None:
        # ``None`` (or any falsy value) means "use the condenser defaults".
        condenser_kwargs = condenser_kwargs or {}
        self._sc = StructureCondenser(**condenser_kwargs)
        self._distorted_tol = distorted_tol

    def featurize(self, s: Structure) -> list[float | bool | str]:
        """Featurizes a structure using robocrystallographer.

        The structure is condensed with the configured ``StructureCondenser``
        and the result is queried through a ``FeaturizerAdapter``. The order
        of the returned values matches the order of :meth:`feature_labels`.

        Args:
            s: A structure.

        Returns:
            The robocrystallographer features.
        """
        fa = FeaturizerAdapter(
            self._sc.condense_structure(s), distorted_tol=self._distorted_tol
        )

        # add general structure features
        features = [
            fa.mineral["type"],
            fa.spg_symbol,
            fa.crystal_system,
            fa.dimensionality,
            fa.is_vdw_heterostructure,
            fa.is_interpenetrated,
            fa.is_intercalated,
        ]

        # add dimensionality features
        features += [fa.is_dimensionality(d) for d in _dimensionalities]
        features += [fa.is_dimensionality(d) for d in _dimensionality_sets]
        features += [d in fa.component_dimensionalities for d in _dimensionalities]

        # add molecule features
        features += [fa.contains_named_molecule]
        features += [fa.contains_molecule(m) for m in _molecules]

        # add geometry features
        features += [fa.contains_geometry_type(g) for g in _geometries]
        features += [fa.contains_geometry_type(g, distorted=True) for g in _geometries]
        features += [
            fa.average_coordination_number,
            fa.average_cation_coordination_number,
            fa.average_anion_coordination_number,
        ]

        # add polyhedral features
        features += [
            fa.contains_polyhedra,
            fa.contains_corner_sharing_polyhedra,
            fa.contains_edge_sharing_polyhedra,
            fa.contains_face_sharing_polyhedra,
        ]

        # add connectivity features
        features += [
            fa.contains_connected_geometry(c, g)
            for c, g in product(_connectivities, connected_geometries)
        ]
        features += [fa.average_corner_sharing_octahedral_tilt_angle]

        # add fractional features
        features += [fa.frac_sites_polyhedra]
        features += [fa.frac_site_geometry(g) for g in _geometries]
        features += [fa.frac_sites_n_coordinate(n) for n in _cns]

        all_distances = fa.all_bond_lengths()

        # add bond length features
        # NOTE(review): max()/min() raise ValueError on an empty sequence;
        # confirm whether a structure without any bonds can reach this point.
        features += [max(all_distances), min(all_distances), mean(all_distances)]

        return features

    def feature_labels(self) -> list[str]:
        """Returns the feature names, in the order produced by ``featurize``.

        Returns:
            The label for each robocrystallographer feature.
        """
        # general features
        labels = [
            "mineral_prototype",
            "spg_symbol",
            "crystal_system",
            "dimensionality",
            "is_vdw_heterostructure",
            "is_interpenetrated",
            "is_intercalated",
        ]

        # dimensionality features
        labels += [f"is_only_{d}d" for d in _dimensionalities]
        labels += [
            "is_{}".format("_".join([f"{d}d" for d in ds]))
            for ds in _dimensionality_sets
        ]
        labels += [f"contains_{d}d_component" for d in _dimensionalities]

        # molecule features
        labels += ["contains_named_molecule"]
        labels += [f"contains_{m}" for m in _molecules]

        # geometry features
        labels += [f"contains_{g}" for g in _geometries]
        labels += [f"contains_distorted_{g}" for g in _geometries]
        labels += ["average_site_cn", "average_cation_cn", "average_anion_cn"]

        # polyhedral features
        labels += [
            "contains_polyhedra",
            "contains_corner_sharing_polyhedra",
            "contains_edge_sharing_polyhedra",
            "contains_face_sharing_polyhedra",
        ]

        # connectivity features
        labels += [
            f"contains_{c}_{g}"
            for c, g in product(_connectivities, connected_geometries)
        ]
        labels += ["corner_sharing_octahedral_tilt_angle"]

        # fractional features
        labels += ["frac_site_polyhedra"]
        labels += [f"frac_sites_{g}" for g in _geometries]
        labels += [f"frac_sites_{n}_coordinate" for n in _cns]

        # bond length features
        labels += ["max_bond_length", "min_bond_length", "average_bond_length"]

        return labels

    def citations(self) -> list[str]:
        """Returns the citation entries for this featurizer.

        Returns:
            A single-element list containing the placeholder "in prep.".
        """
        return ["in prep."]

    def implementors(self) -> list[str]:
        """Returns the names of the people who implemented this featurizer.

        Returns:
            A list of implementor names.
        """
        return ["Alex Ganose"]