File size: 9,021 Bytes
1c703f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4883dc5
1c703f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4883dc5
 
1c703f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import numpy as np
import periodictable

from copy import deepcopy, copy
from .latmatcher import LatMatch, SR


class PipelineLatMatch:
    """Stack a host layer (A) and a guest layer (B) on a common 2D supercell.

    The in-plane supercell ``sc_vec`` is either supplied by the caller or
    searched with ``LatMatch``. Layer B is placed ``dz`` above layer A along z
    (see ``atom_shift``). Atoms are handled as ``[symbol, position]`` entries
    whose positions have (x, y, z) components.
    """

    def __init__(self, Alat3D, Blat3D, Aatoms3D=None, Batoms3D=None, dim=10, sc_vec=None, optimize_angle=True,
                 optimize_strain=True, dz=4):
        """
        :param Alat3D: 3x3 lattice matrix of the host layer (numpy array).
        :param Blat3D: 3x3 lattice matrix of the guest layer (numpy array).
        :param Aatoms3D: host atoms as ``[symbol, position]`` entries; ``None``
            is treated as an empty layer.
        :param Batoms3D: guest atoms as ``[symbol, position]`` entries; ``None``
            is treated as an empty layer.
        :param dim: ``scdim`` forwarded to ``LatMatch`` when the supercell is searched.
        :param sc_vec: optional precomputed 2x2 supercell matrix; when given,
            no ``LatMatch`` search is run and ``self.rez`` stays ``None``.
        :param optimize_angle: forwarded to ``LatMatch``.
        :param optimize_strain: forwarded to ``LatMatch``.
        :param dz: gap along z between the highest A atom and the lowest B atom.
        """
        self.Alat3D = Alat3D
        self.Blat3D = Blat3D
        self.Alat = self.Alat3D[:2, :2]
        self.Blat = self.Blat3D[:2, :2]
        self.dz = dz

        if sc_vec is not None:
            self.sc_vec = sc_vec
            # TODO: This should be computed. Without it compute_super_atoms
            # cannot transform layer B (rotate_guest needs a (s1, s2, theta) triple).
            self.rez = None
        else:
            self.matcher4 = LatMatch(scdim=dim, reference=self.Alat, target=self.Blat,
                                     optimize_angle=optimize_angle, optimize_strain=optimize_strain)
            self.sc_vec = self.matcher4.supercell()
            self.rez = self.matcher4.result

        # In-plane part from sc_vec; the z height covers both lattices' [2][2]
        # entries plus the gap, with a small margin.
        self.sc_vec3 = np.array([[self.sc_vec[0][0], self.sc_vec[0][1], 0],
                                 [self.sc_vec[1][0], self.sc_vec[1][1], 0],
                                 [0, 0, self.dz + self.Alat3D[2][2] + self.Blat3D[2][2] + 0.001]])

        self.Aatoms3D, self.Batoms3D = self.atom_shift(Aatoms3D, Batoms3D)

        self.superA_xyz = []
        self.superB_xyz = []

    def atom_shift(self, Aatoms3D, Batoms3D):
        """Stack layer B above layer A along z and return shifted copies.

        Layer A is translated so its lowest atom sits at z = 0.001. Layer B is
        translated so its lowest atom sits ``self.dz + 0.001`` above the highest
        (shifted) atom of layer A, or at ``self.dz + 0.001`` when A is empty.
        The inputs are deep-copied and not modified; ``None`` is treated as an
        empty layer.

        Side effect: the results are also stored on ``self.Aatoms3D`` and
        ``self.Batoms3D``.

        :return: ``(shifted_A_atoms, shifted_B_atoms)``
        """
        shift_Aatoms3D = deepcopy(Aatoms3D) if Aatoms3D is not None else []
        shift_Batoms3D = deepcopy(Batoms3D) if Batoms3D is not None else []

        za = [atom[1][-1] for atom in shift_Aatoms3D]
        zb = [atom[1][-1] for atom in shift_Batoms3D]

        top_a = 0.0
        if za:
            min_za = min(za)
            za = [z - min_za + 0.001 for z in za]
            top_a = max(za)
        if zb:
            min_zb = min(zb)
            zb = [z - min_zb + top_a + self.dz + 0.001 for z in zb]

        for atom, z in zip(shift_Aatoms3D, za):
            atom[1][-1] = z
        for atom, z in zip(shift_Batoms3D, zb):
            atom[1][-1] = z

        self.Aatoms3D = shift_Aatoms3D
        self.Batoms3D = shift_Batoms3D
        return self.Aatoms3D, self.Batoms3D

    def compute_super_atoms(self, grid_dim=(10, 10, 0)):
        """Select the host and guest atoms lying inside the supercell.

        Each layer is replicated with ``atoms_to_greed`` over ``grid_dim``
        translations. Layer B is additionally transformed by
        ``rotate_guest(self.rez, ...)``. The images are converted to supercell
        fractional coordinates and filtered with ``supar_atoms``, then converted
        back to Cartesian coordinates and de-duplicated.

        :param grid_dim: ``(nx, ny, nz)`` images on each side of the origin cell
            used to cover the supercell; larger supercells may need more.
        :return: ``(superA_xyz, superB_xyz)``; the same lists are also stored on
            ``self.superA_xyz`` / ``self.superB_xyz``.
        """
        atoms = atoms_to_greed(self.Aatoms3D, lat_v=self.Alat3D, dim=grid_dim)  # initial grid of atoms xyz
        atoms_a = atom_change_basis2D(atoms, new_basis=self.Alat, old_basis=np.identity(2))  # grid of atoms in A basis
        atoms_A = atom_change_basis2D(atoms_a, new_basis=self.sc_vec,
                                      old_basis=self.Alat)  # grid of atoms in super cell basis
        superA = supar_atoms(atoms_A)  # select the atoms from super cell
        superA_xyz = atom_change_basis2D(superA, new_basis=np.identity(2),
                                         old_basis=self.sc_vec)  # move the atoms back to xyz basis

        atoms = atoms_to_greed(self.Batoms3D, lat_v=self.Blat3D, dim=grid_dim)  # initial grid of atoms
        # NOTE(review): self.rez is None when sc_vec was passed to __init__, and
        # rotate_guest then fails to unpack it.
        _, atoms_b = rotate_guest(self.rez, self.Blat3D, atoms)
        atoms_B = atom_change_basis2D(atoms_b, new_basis=self.sc_vec,
                                      old_basis=np.identity(2))  # grid of atoms in super cell basis
        superB = supar_atoms(atoms_B)  # select the atoms from super cell
        superB_xyz = atom_change_basis2D(superB, new_basis=np.identity(2),
                                         old_basis=self.sc_vec)  # move the atoms back to xyz basis

        self.superA_xyz = uniq_list(superA_xyz)
        self.superB_xyz = uniq_list(superB_xyz)

        return self.superA_xyz, self.superB_xyz

    @staticmethod
    def _atomic_number(symbol):
        """Return the atomic number for ``symbol``, or None (with a warning) if unknown."""
        # A default is needed: plain getattr raises AttributeError for unknown names.
        element = getattr(periodictable, symbol, None)
        number = getattr(element, "number", None)
        if number is None:
            print(f"Warning: Atomic number for element '{symbol}' not found.")
        return number

    def get_new_structure(self, inter_distance=(0, 0, 0)):
        """Assemble the heterostructure as a plain dict.

        :param inter_distance: ``(dx, dy, dz)`` offset added to every guest
            (layer B) position; any 3-element sequence is accepted. The default
            applies no offset.
        :return: dict with keys ``atoms`` (atomic numbers, ``None`` for unknown
            symbols), ``lattice_vectors`` (``self.sc_vec3``), ``pbc``,
            ``positions`` (coordinate lists) and ``host_guest`` (``"host"`` for
            layer A atoms, then ``"guest"`` for layer B atoms).
        """
        superA_xyz, superB_xyz = self.compute_super_atoms()
        # Element-wise vector add; list `+` would concatenate instead of shifting.
        offset = np.asarray(inter_distance, dtype=float)

        atomic_numbers = []
        positions = []
        host_guest = []

        for element in superA_xyz:
            host_guest.append("host")
            positions.append(np.asarray(element[1]).tolist())
            atomic_numbers.append(self._atomic_number(element[0]))

        for element in superB_xyz:
            host_guest.append("guest")
            positions.append((np.asarray(element[1]) + offset).tolist())
            atomic_numbers.append(self._atomic_number(element[0]))

        return {"atoms": atomic_numbers,
                "lattice_vectors": self.sc_vec3,
                "pbc": [True, True, False],
                "positions": positions,
                "host_guest": host_guest, }


def is_element_in_list(element, lst):
    """Return True if ``lst`` already holds an atom matching ``element``.

    Two atoms match when their symbols (index 0) are equal and their
    positions (index 1) agree within ``np.allclose``'s default tolerances.
    """
    return any(
        candidate[0] == element[0] and np.allclose(candidate[1], element[1])
        for candidate in lst
    )


def uniq_list(super_brut):
    """Drop duplicate atoms, keeping the first occurrence of each.

    Duplicates are detected with ``is_element_in_list`` (same symbol and
    ``np.allclose`` positions). Every candidate is compared against all atoms
    kept so far, so the cost grows quadratically in the worst case.
    """
    unique_atoms = []
    for atom in super_brut:
        if not is_element_in_list(atom, unique_atoms):
            unique_atoms.append(atom)
    return unique_atoms


def atoms_to_greed(atoms, lat_v, dim):
    """
    Construct a grid of periodic images of ``atoms`` ("greed" = grid).

    The atoms are replicated one axis at a time: first along ``lat_v[:, 0]``,
    then ``lat_v[:, 1]``, then ``lat_v[:, 2]``. Each stage translates every
    atom produced so far by ``+i`` and ``-i`` times that axis' vector for
    ``i = 1 .. dim[axis]``. The result is the full
    ``(2*dim[0]+1) x (2*dim[1]+1) x (2*dim[2]+1)`` block of images, each image
    generated once, with the original atoms first.

    :param atoms: list of ``[symbol, position]`` entries; positions must have
        three mutable components.
    :param lat_v: 3x3 lattice matrix whose columns are the lattice vectors.
    :param dim: ``(nx, ny, nz)`` number of images on each side of the origin
        cell along each axis; 0 disables that axis.
    :return: new list of deep-copied atoms; the input is not modified.
    """
    atom_list = deepcopy(atoms)

    for axis in range(3):
        # Fresh list per axis so images from earlier axes are not appended twice.
        shifted = []
        for i in range(1, dim[axis] + 1):
            for atom in atom_list:
                for sign in (1, -1):
                    new_atom = deepcopy(atom)
                    for k in range(3):
                        new_atom[1][k] += sign * i * lat_v[k][axis]
                    shifted.append(new_atom)
        atom_list.extend(shifted)

    return atom_list


# atoms in new basis:

def atom_change_basis2D(atoms, new_basis, old_basis=np.identity(2)):
    """Re-express the in-plane (x, y) coordinates of ``atoms`` in another basis.

    The column vector ``(x, y)`` is multiplied by ``inv(new_basis) @ old_basis``,
    so basis vectors are the columns of the 2x2 matrices. Only the top-left
    2x2 entries of each basis are read, so 3x3 lattices are accepted too.
    Components from index 2 on (e.g. z) are left unchanged.

    :param atoms: iterable of ``[symbol, position]`` entries; positions must
        support item assignment.
    :param new_basis: basis to express the coordinates in.
    :param old_basis: basis the input coordinates are given in (default: the
        Cartesian identity).
    :return: new list of deep-copied atoms; the input is not modified.
    """
    new_2d = [[new_basis[0][0], new_basis[0][1]],
              [new_basis[1][0], new_basis[1][1]]]
    old_2d = [[old_basis[0][0], old_basis[0][1]],
              [old_basis[1][0], old_basis[1][1]]]
    # Computed once for all atoms: maps old-basis coordinates to new-basis ones.
    change_base = np.linalg.inv(new_2d) @ old_2d

    new_atoms = []
    for atom in atoms:
        # A single deep copy per atom is enough to keep the caller's data intact.
        new_atom = deepcopy(atom)
        nd = change_base @ np.array(new_atom[1][:2])
        new_atom[1][0] = nd[0]
        new_atom[1][1] = nd[1]
        new_atoms.append(new_atom)

    return new_atoms


def supar_atoms(atoms, eps=0.01):
    """Keep the atoms whose fractional (x, y) coordinates lie in one unit cell.

    The accepted window on each in-plane axis is ``[-eps, 1 - eps)``. It is
    shifted down by ``eps`` so atoms sitting on the cell origin are kept even
    when round-off makes them slightly negative. Its length is exactly 1, so
    at most one periodic image of each atom is kept: an atom at ``-1e-9`` and
    its image at ``0.999999999`` are not both selected.

    :param atoms: ``[symbol, position]`` entries in fractional (supercell)
        coordinates.
    :param eps: tolerance applied to both cell boundaries.
    :return: deep copies of the selected atoms, in input order; the input is
        not modified.
    """
    upper = 1 - eps

    def _inside(frac):
        # Half-open window of length 1 -> one representative per image class.
        return -eps <= frac < upper

    return [deepcopy(atom) for atom in atoms
            if _inside(atom[1][0]) and _inside(atom[1][1])]


def rotate_guest(rez, Blat3D, atoms):
    """Apply the in-plane transform ``SR(s1, s2, theta)`` to the guest layer.

    The 2x2 matrix from ``SR`` is embedded in a 3x3 identity, so the z axis is
    left unchanged.

    :param rez: ``(s1, s2, theta)`` triple forwarded to ``SR`` (presumably the
        strain/angle result of ``LatMatch`` -- confirm).
    :param Blat3D: 3x3 guest lattice matrix (numpy array).
    :param atoms: ``[symbol, position]`` entries with 3-component positions.
    :return: ``(oBlat3D, new_atoms)``: the transformed lattice and a new list
        of ``[symbol, transformed_position]`` pairs.
    """
    positions = [entry[1] for entry in atoms]
    s1, s2, theta = rez

    transform = np.eye(3)
    transform[:2, :2] = SR(s1, s2, theta)

    # NOTE(review): this transforms the rows of Blat3D, whereas atoms_to_greed
    # reads lattice vectors as columns; confirm the convention before relying
    # on the returned lattice.
    oBlat3D = (transform @ Blat3D.T).T

    rotated = [transform @ pos for pos in positions]
    new_atoms = [[entry[0], pos] for entry, pos in zip(atoms, rotated)]

    return oBlat3D, new_atoms