File size: 27,361 Bytes
d14eae5
 
 
 
 
 
 
 
 
f294ad1
d14eae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
import torch
import torch.nn as nn
from functools import partial, cache
from argparse import Namespace
from typing import List, Tuple, Dict, Union, Optional
from itertools import chain
import random
from typing import Literal

from transformers import T5Tokenizer

class Graph:
    """
    A graph stored as a list of (head, relation, tail) triplets.

    The derived attributes ``concepts`` (sorted unique heads/tails), ``relations``
    (sorted unique relations) and ``relations_multiple`` (relations in triplet order,
    duplicates kept) are computed on construction and recomputed whenever ``g`` is
    reassigned. In-place mutation of the list (e.g. ``graph.g.append(...)``) does NOT
    refresh them; reassign ``graph.g`` instead.

    :param g: A list of tuples, where each tuple is a triple (head, r, tail).
        Defaults to an empty graph. The list is stored by reference, not copied.
    """
    def __init__(
            self,
            g: Optional[List[Tuple[str,str,str]]] = None
        ):
        # Use a fresh list per instance: a literal `[]` default would be one list
        # object shared by every Graph constructed without arguments.
        # Assigning through the property also fills the derived attributes.
        self.g = [] if g is None else g

    @property
    def g(self) -> List[Tuple[str,str,str]]:
        return self._g

    @g.setter
    def g(self, g: List[Tuple[str,str,str]]):
        self._g = g
        # Keep the derived views consistent with the new triplets.
        self.concepts = self.get_concepts()  # list of all concepts in the graph
        self.relations = self.get_relations()  # list of all relations in the graph
        self.relations_multiple = self.get_relations_multiple()  # list of all relations in the graph, including duplicate relations

    def num_triplets(self) -> int:
        """
        Get the number of triplets in the graph.
        """
        return len(self.g)

    def get_concepts(self) -> List[str]:
        """
        Get the concepts (all heads and tails) in the graph, deduplicated and sorted.
        """
        # sorting is not necessary but makes debugging easier
        return sorted({triplet[i] for triplet in self.g for i in (0, 2)})

    def get_relations(self) -> List[str]:
        """
        Get the relations in the graph, deduplicated and sorted.
        """
        # sorting is not necessary but makes debugging easier
        return sorted(set(self.get_relations_multiple()))

    def get_relations_multiple(self) -> List[str]:
        """
        Get the relations in the graph in triplet order, including duplicate relations.
        """
        return [triplet[1] for triplet in self.g]

    def __str__(self):
        return '\n'.join(str(triplet) for triplet in self.g)

class Data(Namespace):
    """
    Plain attribute container for one encoded example.

    Each keyword argument becomes an instance attribute, in the order given, so
    ``Data(input_ids=x).input_ids is x``. Equality, ``repr`` and ``in`` come from
    ``argparse.Namespace``.
    """
    def __init__(self, **kwargs):
        super().__init__()
        for name, value in kwargs.items():
            self.__dict__[name] = value

def get_dummy_graph(num_triplets:int=3) -> Graph:
    """
    Build a tiny hard-coded ``IsA`` hierarchy, handy for debugging and quick tests.

    :param num_triplets: how many of the four built-in triplets to keep, taken from
        the front of the list. Must be <= 4 (checked with ``assert``). Negative values
        are not rejected and follow Python slice semantics.
    :return: a ``Graph`` made of the selected triplets.
    """
    toy_triplets = [
        ("dog", "IsA", "animal"),
        ("cat", "IsA", "animal"),
        ("black poodle", "IsA", "dog"),
        ("black cat", "IsA", "cat"),
    ]
    assert num_triplets <= 4, "num_triplets must be <= 4"
    selected = toy_triplets[:num_triplets]
    return Graph(selected)

def r2nl(r: str) -> str:
    """
    Map a concept or relation label to the natural-language string that gets tokenized.

    Currently the identity: the label is returned unchanged (same object). This is the
    hook for dataset-specific rewrites of labels.
    """
    natural_language = r
    return natural_language

def _get_str2tok(g:Graph, tokenizer: T5Tokenizer) -> dict[str, list[int]]:
    """
    Map every concept and relation label of ``g`` to its list of token ids.

    Labels are passed through ``r2nl`` and tokenized without padding, concepts and
    relations in two separate tokenizer calls. A trailing EOS id appended by the
    tokenizer is stripped from each label. The extra key ``'</s>'`` maps to
    ``[tokenizer.eos_token_id]``.

    :param g: graph whose ``concepts`` and ``relations`` are tokenized
    :param tokenizer: called as ``tokenizer(list_of_str, padding=False)['input_ids']``
    :return: dict label -> list of token ids
    """
    concept_texts = [r2nl(c) for c in g.concepts]
    relation_texts = [r2nl(r) for r in g.relations]
    tokens = (
        tokenizer(concept_texts, padding=False)['input_ids']
        + tokenizer(relation_texts, padding=False)['input_ids']
    )
    # Not necessarily every Levi-graph node: a relation can occur several times in g.
    node_names = g.concepts + g.relations
    assert len(tokens) == len(node_names), f"{len(tokens) = }, {len(node_names) = }"

    eos_id = tokenizer.eos_token_id
    str2tok = {
        name: (ids[:-1] if ids[-1] == eos_id else ids)  # drop the tokenizer's trailing EOS
        for name, ids in zip(node_names, tokens)
    }
    str2tok['</s>'] = [eos_id]
    return str2tok

def _get_graphT5_input_sequence(g:Graph, str2tok:dict, use_eos:bool) -> Tuple[torch.Tensor, Dict[str, List[Tuple[int, int]]], torch.Tensor, Dict[str, List[Tuple[int, int]]]]:
    """
    Lay out all Levi-graph nodes of ``g`` as one flat token sequence.

    Node order: every relation occurrence in the order of ``g.relations_multiple`` (i.e.
    triplet order), then every concept in the order of ``g.concepts``, then -- if
    ``use_eos`` -- the ``'</s>'`` node.

    :param g: the graph
    :param str2tok: maps every relation / concept label (and ``'</s>'`` when ``use_eos``)
        to its list of token ids
    :param use_eos: whether to append the ``'</s>'`` node
    :return: ``(sequence, indices, is_concept, concept_indices)``:
        - sequence: long tensor of shape (1, num_tokens) with the concatenated token ids
        - indices: label -> list of (start, end) token spans, one per occurrence. A label
          that is both a relation and a concept gets its relation spans first and its
          concept span LAST.
        - is_concept: bool tensor of shape (1, num_tokens), True on tokens of concept nodes
        - concept_indices: concept -> [(start, end)] holding only the concept's own span
    """
    relation_nodes = g.relations_multiple
    concept_nodes = g.concepts
    eos_nodes = ['</s>'] if use_eos else []
    all_nodes = relation_nodes + concept_nodes + eos_nodes  # new list; g is not modified

    all_tokens = [str2tok[node] for node in all_nodes]  # one list of token ids per node

    n_rel, n_con = len(relation_nodes), len(concept_nodes)
    num_relation_tokens = sum(len(toks) for toks in all_tokens[:n_rel])
    num_concept_tokens = sum(len(toks) for toks in all_tokens[n_rel:n_rel + n_con])
    # Derived from str2tok instead of assuming '</s>' is exactly one token.
    num_eos_tokens = sum(len(toks) for toks in all_tokens[n_rel + n_con:])

    # One flag per TOKEN (not per node): True where the token belongs to a concept.
    is_concept = torch.tensor(
        [False] * num_relation_tokens + [True] * num_concept_tokens + [False] * num_eos_tokens,
        dtype=torch.bool,
    )

    # WARNING: concepts and relations may share a name (e.g. in REBEL); such a label
    # collects all its spans in one list, relation spans first, concept span last.
    indices = {node: [] for node in all_nodes}
    index_counter = 0
    for node, toks in zip(all_nodes, all_tokens):
        indices[node].append((index_counter, index_counter + len(toks)))
        index_counter += len(toks)

    # [-1]: concepts are laid out after relations, so their own span is always last.
    concept_indices = {node: [indices[node][-1]] for node in g.concepts}

    sequence = torch.tensor(list(chain.from_iterable(all_tokens)), dtype=torch.long)
    return sequence.unsqueeze(0), indices, is_concept.unsqueeze(0), concept_indices  # add batch dimension

def _get_graphT5_relativeposition_sparsitymask(g:Graph, indices:dict, sequence_length:int, use_eos:bool, eos:str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the relative positions and attention sparsity mask of the local GLM (lGLM).

    Only pairs of tokens that lie in the same node, or in two nodes of the same triplet,
    are connected (``sparsity_mask`` True). All other pairs stay masked out with relative
    position 0. Within a triplet (head, relation, tail) the nodes are positioned as the
    consecutive text "head relation tail". For example, a head token at offset ``ih`` and
    a relation token at offset ``ir`` get relative position ``l_h - ih + ir`` (negated in
    the opposite direction). When a pair is covered by several triplets, the last triplet
    in ``g.g`` wins.

    :param g: graph whose triplets define which nodes are connected
    :param indices: node label -> list of (start, end) token spans, as built by
        ``_get_graphT5_input_sequence``. A label that is both a relation and a concept
        holds its relation spans first and its concept span LAST.
    :param sequence_length: number of tokens in the sequence
    :param use_eos: whether the sequence contains a single-token ``'</s>'`` node
    :param eos: ``'bidirectional'`` or ``'unidirectional'``; only read when ``use_eos``
    :return: ``(relative_position, sparsity_mask, use_additional_bucket)``, each of shape
        (1, sequence_length, sequence_length). ``use_additional_bucket`` is all False.
    :raises ValueError: if ``use_eos`` and ``eos`` is not a valid option.
    """
    relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
    sparsity_mask = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)
    # lGLM never uses the additional bucket
    use_additional_bucket = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.bool)

    # relative positions / sparsity within each node
    for start, end in chain.from_iterable(indices.values()):
        relative_position[start:end, start:end] = _get_relative_position(end-start)
        sparsity_mask[start:end, start:end] = True

    # relative position between nodes of the same triplet
    relation_counter = {relation: 0 for relation in g.relations}  # how often each relation has been consumed so far
    for triplet in g.g:
        # Heads and tails are concepts: their own span is the LAST entry of indices[...].
        # Using [0] would pick a relation span when a relation has the same name.
        pos_h = indices[triplet[0]][-1]  # (start_index, end_index) of head
        pos_r = indices[triplet[1]][relation_counter[triplet[1]]]  # (start_index, end_index) of this relation occurrence
        pos_t = indices[triplet[2]][-1]  # (start_index, end_index) of tail

        l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0]  # number of tokens of head and relation

        # Iterate over all token pairs of the triplet. Plain loops keep the exact
        # overwrite order (matters for self-loops where head == tail).
        for ih, ph in enumerate(range(pos_h[0], pos_h[1])):  # head tokens
            for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # relation tokens
                relative_position[ph, pr] = l_h - ih + ir
                relative_position[pr, ph] = - (l_h - ih + ir)
                sparsity_mask[ph, pr] = True
                sparsity_mask[pr, ph] = True
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # tail tokens
                relative_position[ph, pt] = l_h - ih + l_r + it
                relative_position[pt, ph] = - (l_h - ih + l_r + it)
                sparsity_mask[ph, pt] = True
                sparsity_mask[pt, ph] = True
        for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # relation tokens
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # tail tokens
                relative_position[pr, pt] = l_r - ir + it
                relative_position[pt, pr] = - (l_r - ir + it)
                sparsity_mask[pr, pt] = True
                sparsity_mask[pt, pr] = True

        relation_counter[triplet[1]] += 1  # the next occurrence of this relation uses its next span

    if use_eos:
        assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
        pos_eos = indices['</s>'][0]  # (start_index, end_index) of eos
        assert pos_eos[0] + 1 == pos_eos[1], pos_eos
        pos_eos = pos_eos[0]  # position of the eos token

        if eos == 'bidirectional':
            # +1e6 / -1e6 stand in for +/- infinity (stored as integers in the long tensor)
            relative_position[:, pos_eos] = +1e6
            relative_position[pos_eos, :] = -1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[:, pos_eos] = True
            sparsity_mask[pos_eos, :] = True
        elif eos == 'unidirectional':
            relative_position[:, pos_eos] = 1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[pos_eos, :] = False  # no messages from eos to other tokens
            sparsity_mask[:, pos_eos] = True
        else:
            raise ValueError(f'{eos = } is not a valid option.')

    relative_position = relative_position.unsqueeze(0)  # add batch dimension
    sparsity_mask = sparsity_mask.unsqueeze(0)  # add batch dimension
    use_additional_bucket = use_additional_bucket.unsqueeze(0)  # add batch dimension
    return relative_position, sparsity_mask, use_additional_bucket

def _get_global_graphT5_relativeposition_sparsitymask(g:Graph, indices:dict, sequence_length:int, use_eos:bool, eos:str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build relative positions, sparsity mask and additional-bucket mask of the global GLM (gGLM).

    ``sparsity_mask`` starts all True (every token attends to every token). Pairs of
    tokens in the same node, or in two nodes of the same triplet, get graph-based
    relative positions (as if the triplet were the text "head relation tail") and
    ``use_additional_bucket`` False. All remaining pairs keep relative position 0 and
    ``use_additional_bucket`` True. When a pair is covered by several triplets, the last
    triplet in ``g.g`` wins.

    :param g: graph whose triplets define the local neighbourhoods
    :param indices: node label -> list of (start, end) token spans, as built by
        ``_get_graphT5_input_sequence``. A label that is both a relation and a concept
        holds its relation spans first and its concept span LAST.
    :param sequence_length: number of tokens in the sequence
    :param use_eos: whether the sequence contains a single-token ``'</s>'`` node
    :param eos: ``'bidirectional'`` or ``'unidirectional'``; only read when ``use_eos``
    :return: ``(relative_position, sparsity_mask, use_additional_bucket)``, each of shape
        (1, sequence_length, sequence_length)
    :raises ValueError: if ``use_eos`` and ``eos`` is not a valid option.
    """
    relative_position = torch.zeros(size=(sequence_length, sequence_length), dtype=torch.long)
    sparsity_mask = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)  # could switch to None, but then code has to be updated accordingly (in particular get_batch)
    use_additional_bucket = torch.ones(size=(sequence_length, sequence_length), dtype=torch.bool)

    # relative positions within each node (these pairs use the regular buckets)
    for start, end in chain.from_iterable(indices.values()):
        relative_position[start:end, start:end] = _get_relative_position(end-start)
        use_additional_bucket[start:end, start:end] = False

    # relative position between nodes of the same triplet
    relation_counter = {relation: 0 for relation in g.relations}  # how often each relation has been consumed so far
    for triplet in g.g:
        # Heads and tails are concepts: their own span is the LAST entry of indices[...].
        # Using [0] would pick a relation span when a relation has the same name.
        pos_h = indices[triplet[0]][-1]  # (start_index, end_index) of head
        pos_r = indices[triplet[1]][relation_counter[triplet[1]]]  # (start_index, end_index) of this relation occurrence
        pos_t = indices[triplet[2]][-1]  # (start_index, end_index) of tail

        l_h, l_r = pos_h[1] - pos_h[0], pos_r[1] - pos_r[0]  # number of tokens of head and relation

        # Iterate over all token pairs of the triplet. Plain loops keep the exact
        # overwrite order (matters for self-loops where head == tail).
        for ih, ph in enumerate(range(pos_h[0], pos_h[1])):  # head tokens
            for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # relation tokens
                relative_position[ph, pr] = l_h - ih + ir
                relative_position[pr, ph] = - (l_h - ih + ir)
                use_additional_bucket[ph, pr] = False
                use_additional_bucket[pr, ph] = False
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # tail tokens
                relative_position[ph, pt] = l_h - ih + l_r + it
                relative_position[pt, ph] = - (l_h - ih + l_r + it)
                use_additional_bucket[ph, pt] = False
                use_additional_bucket[pt, ph] = False
        for ir, pr in enumerate(range(pos_r[0], pos_r[1])):  # relation tokens
            for it, pt in enumerate(range(pos_t[0], pos_t[1])):  # tail tokens
                relative_position[pr, pt] = l_r - ir + it
                relative_position[pt, pr] = - (l_r - ir + it)
                use_additional_bucket[pr, pt] = False
                use_additional_bucket[pt, pr] = False

        relation_counter[triplet[1]] += 1  # the next occurrence of this relation uses its next span

    # Configure eos once, after all triplets. Triplet writes never touch the eos row or
    # column, so this matches the per-triplet application when g has triplets, and it
    # now also happens for a graph without triplets.
    if use_eos:
        assert len(indices['</s>']) == 1, f"{indices['</s>'] = } should have length 1"
        pos_eos = indices['</s>'][0]  # (start_index, end_index) of eos
        assert pos_eos[0] + 1 == pos_eos[1], pos_eos
        pos_eos = pos_eos[0]  # position of the eos token

        if eos == 'bidirectional':
            # +1e6 / -1e6 stand in for +/- infinity (stored as integers in the long tensor)
            relative_position[:, pos_eos] = +1e6
            relative_position[pos_eos, :] = -1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[:, pos_eos] = True
            sparsity_mask[pos_eos, :] = True
            use_additional_bucket[:, pos_eos] = False
            use_additional_bucket[pos_eos, :] = False
        elif eos == 'unidirectional':
            relative_position[:, pos_eos] = 1e6
            relative_position[pos_eos, pos_eos] = 0
            sparsity_mask[pos_eos, :] = False  # no messages from eos to other tokens
            sparsity_mask[:, pos_eos] = True
            use_additional_bucket[:, pos_eos] = False
            use_additional_bucket[pos_eos, :] = False
        else:
            raise ValueError(f'{eos = } is not a valid option.')

    relative_position = relative_position.unsqueeze(0)  # add batch dimension
    sparsity_mask = sparsity_mask.unsqueeze(0)  # add batch dimension
    use_additional_bucket = use_additional_bucket.unsqueeze(0)  # add batch dimension
    return relative_position, sparsity_mask, use_additional_bucket

def graph_to_graphT5(g:Graph, tokenizer:T5Tokenizer, how:str, eos:str)->Data:
    """
    Encode a graph into the GraphT5 model input.

    :param g: a ``Graph`` or anything ``Graph(...)`` accepts (a list of triplets)
    :param tokenizer: tokenizer used for the concept and relation labels
    :param how: ``'local'`` (lGLM, no additional buckets) or ``'global'`` (gGLM, one
        additional bucket for long-range graph-to-graph pairs)
    :param eos: ``False``/``'False'`` for no eos token, ``'bidirectional'`` (eos connected
        with every token in both directions, at relative position +/- 1e6) or
        ``'unidirectional'`` (every token sends to eos, eos sends to no token, which
        preserves locality for the local GLM). The value is converted with ``str``.
    :return: ``Data`` with input_ids, relative_position, sparsity_mask,
        use_additional_bucket, indices, is_concept, concept_indices and
        num_additional_buckets
    :raises ValueError: if ``how`` is neither ``'local'`` nor ``'global'``
    """
    if not isinstance(g, Graph):
        g = Graph(g)
    eos = str(eos)
    assert eos in ['False', 'bidirectional', 'unidirectional'], f"{eos = } must be either 'False', 'bidirectional', or 'unidirectional'"
    use_eos: bool = eos != 'False'

    label_tokens = _get_str2tok(g, tokenizer)
    sequence, indices, is_concept, concept_indices = _get_graphT5_input_sequence(g, label_tokens, use_eos)

    if how == 'local':
        build_attention, num_additional_buckets = _get_graphT5_relativeposition_sparsitymask, 0
    elif how == 'global':
        build_attention, num_additional_buckets = _get_global_graphT5_relativeposition_sparsitymask, 1
    else:
        raise ValueError(f"how must be either 'local' or 'global', but is {how}")

    relative_position, sparsity_mask, use_additional_bucket = build_attention(
        g, indices, sequence.shape[1], use_eos, eos
    )

    return Data(
        input_ids=sequence,
        relative_position=relative_position,
        sparsity_mask=sparsity_mask,
        use_additional_bucket=use_additional_bucket,
        indices=indices,
        is_concept=is_concept,
        concept_indices=concept_indices,
        num_additional_buckets=num_additional_buckets,
    )

@cache
def _get_relative_position(size):
    """
    Return the (size, size) long tensor of relative positions inside one span.

    Entry ``[j, i]`` equals ``i - j``, e.g. size 3 gives
    ``[[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]``. Size 0 gives an empty (0, 0) tensor.

    The result is memoized by ``functools.cache``, so the SAME tensor object is
    returned for repeated sizes. Callers must not modify it in place; they should only
    copy it into slices of larger tensors.
    """
    positions = torch.arange(size, dtype=torch.long)
    # broadcast (1, size) - (size, 1) -> (size, size) with [j, i] = i - j
    return positions.unsqueeze(0) - positions.unsqueeze(1)

def get_embedding(
        sequence_embedding: torch.Tensor,
        indices: Dict[str, List[Tuple[int, int]]],
        concept: str,
        embedding_aggregation: str = "mean",
    ):
    """
    Extract the embedding of one concept from the embedding of the whole sequence.

    :param sequence_embedding: embedding of the full sequence, indexed as
        ``sequence_embedding[start:end, :]``, i.e. (sequence_length, embedding_size)
    :param indices: node label -> list of (start, end) token spans. The label must
        have exactly one span. A concept that shares its name with a relation has
        several spans here and is rejected; the ``concept_indices`` produced alongside
        ``indices`` hold only the concept span.
    :param concept: label of the concept to extract
    :param embedding_aggregation: ``"mean"`` to average the concept's token rows
        (keeping the row dimension), ``"seq"`` to return those rows unchanged
    :return: tensor of shape (1, embedding_size) for ``"mean"`` or
        (number_of_tokens, embedding_size) for ``"seq"``
    :raises NotImplementedError: for any other ``embedding_aggregation``
    """
    assert concept in indices.keys(), f"{concept = } is not a node in the graph. {indices = }"
    spans = indices[concept]
    assert len(spans) == 1, f"{concept = } is not a concept, as concepts occur only once in the graph. {indices = }"

    first, last = spans[0]
    concept_rows = sequence_embedding[first:last, :]
    if embedding_aggregation == "seq":
        return concept_rows
    if embedding_aggregation == "mean":
        return concept_rows.mean(dim=0, keepdim=True)
    raise NotImplementedError(f"{embedding_aggregation = } is not supported. Use either 'mean' or 'seq'.")

def add_text_to_graph_data(data, text, tokenizer, use_text):
    """
    Append the tokenized ``text`` after the graph tokens stored in ``data``, in place.

    :param data: encoded example with ``input_ids`` of shape (1, graph_len). Either all
        of ``relative_position`` / ``sparsity_mask`` / ``use_additional_bucket`` are set
        (as produced by ``graph_to_graphT5``, each (1, graph_len, graph_len)), or all of
        them are None (sequence transformer).
    :param text: text tokenized with ``tokenizer(text, padding=False)['input_ids']``
    :param tokenizer: tokenizer
    :param use_text: ``'False'``, ``''``, ``False`` or ``None`` disable the text, leaving
        ``data`` untouched. ``'FullyConnected'`` or ``True`` connect every text token with
        every graph token. For graph data any other value raises ``ValueError``. For a
        sequence transformer only the ids are concatenated and the value is not checked.
    :return: None. ``data.input_ids`` and ``data.is_graph`` (True on graph tokens) are
        set. For graph data the three matrices are enlarged and
        ``data.num_additional_buckets`` grows by 2.
    """
    if use_text in {'False', '', False, None}:
        return None

    text_ids = torch.tensor(tokenizer(text, padding=False)['input_ids']).unsqueeze(0)
    combined_ids = torch.cat([data.input_ids, text_ids], dim=1)
    graph_len = data.input_ids.shape[1]
    text_len = text_ids.shape[1]
    total_len = combined_ids.shape[1]

    is_graph = torch.zeros(size=(1, total_len), dtype=torch.bool)
    is_graph[:, :graph_len] = True

    if data.relative_position is None:  # sequence transformer: only the ids grow
        assert data.sparsity_mask is None
        assert data.use_additional_bucket is None
        data.input_ids = combined_ids
        data.is_graph = is_graph
        return None

    every = slice(None)
    graph_part = slice(None, graph_len)
    text_part = slice(graph_len, None)

    def grow(square):
        """Zero-pad a (1, graph_len, graph_len) tensor to (1, total_len, total_len), keeping its dtype."""
        grown = torch.zeros(size=(1, total_len, total_len), dtype=square.dtype)
        grown[every, graph_part, graph_part] = square
        return grown

    relative_position = grow(data.relative_position)
    relative_position[every, text_part, text_part] = _get_relative_position(text_len)

    sparsity_mask = grow(data.sparsity_mask)
    sparsity_mask[every, text_part, text_part] = True

    use_additional_bucket = grow(data.use_additional_bucket)
    use_additional_bucket[every, text_part, text_part] = False  # could change that if we want T2T and local G2G relations to be learned separately

    if use_text not in {'FullyConnected', True}:
        raise ValueError(f"unknown use_text {use_text} (type {type(use_text)})")

    # Text <-> graph pairs: all connected, each direction routed to its own new bucket.
    sparsity_mask[every, text_part, graph_part] = True
    sparsity_mask[every, graph_part, text_part] = True
    use_additional_bucket[every, text_part, graph_part] = True
    use_additional_bucket[every, graph_part, text_part] = True
    relative_position[every, text_part, graph_part] = data.num_additional_buckets  # rows: text, columns: graph
    relative_position[every, graph_part, text_part] = data.num_additional_buckets + 1  # rows: graph, columns: text

    data.input_ids = combined_ids
    data.relative_position = relative_position
    data.sparsity_mask = sparsity_mask
    data.use_additional_bucket = use_additional_bucket
    data.num_additional_buckets = data.num_additional_buckets + 2
    data.is_graph = is_graph
    return None

class DataProcessor():
    """
    Static helpers that encode graphs (optionally with text) into model inputs, batch
    them, and read concept embeddings back out of the model output.
    """
    @staticmethod
    def encode_graph(tokenizer, g:Union[Graph,list[tuple[str,str,str]]], text:Optional[str]=None, how:Literal['global', 'local']='global', eos:str="False")->Data:
        """
        Convert a graph to a suitable input for the model.
        :param tokenizer: tokenizer
        :param g: graph, either a ``Graph`` or a list of (head, relation, tail) triplets
        :param text: text to add to the graph. Can be None if no text should be added.
        :param how: how to represent the graph. Can be 'local' or 'global' for lGLM and gGLM respectively.
        :param eos: end-of-sequence token. Can be `False` for not using an eos token. This is the method used in the paper. When using an eos token, there are two ways to use it: `bidirectional` means that the eos token is connected to every other node in the graph. `unidirectional` means that the eos token is connected to every node in the graph (from node to eos), but not the other way around (i.e. no connection from eos to other node). This means that nodes do not get messages from the eos token, which preserves locality when using the local GLM.
        :return: Data object
        """
        if not isinstance(g, Graph):
            g = Graph(g)
        data = graph_to_graphT5(g, tokenizer, how, eos)
        if text is not None:
            add_text_to_graph_data(data, text, tokenizer, use_text=True)
        return data

    @staticmethod
    def to_batch(data_instances:list[Data], tokenizer, max_seq_len:Optional[int]=None, device:str='cpu', **kwargs)->dict:
        """
        Collate a list of Data instances into padded batch inputs for the GLM forward call.

        Each instance is right-padded to the longest instance (or truncated to
        ``max_seq_len`` when that is smaller). Input ids are padded with
        ``tokenizer.pad_token_id``, relative positions with 0, and the sparsity mask and
        use_additional_bucket with False.

        :param data_instances: non-empty list of Data instances. Either all of them
            carry graph matrices (``relative_position`` set) or none of them do
            (sequence transformer).
        :param tokenizer: tokenizer; only ``pad_token_id`` is read
        :param max_seq_len: optional upper bound on the batch sequence length
        :param device: device on which the batch tensors are allocated
        :param kwargs: extra entries copied verbatim into the returned dict
        :return: dictionary with keys 'input_ids' (batch, L), 'relative_position',
            'sparsity_mask' and 'use_additional_bucket' (each (batch, L, L), or None for a
            sequence transformer), plus ``kwargs``
        :raises ValueError: if ``data_instances`` is empty or mixes graph and
            sequence-transformer instances
        """
        if not data_instances:
            raise ValueError("data_instances must contain at least one Data instance")

        current_max_seq_len = max(data.input_ids.shape[1] for data in data_instances)
        max_seq_len = current_max_seq_len if max_seq_len is None else min(max_seq_len, current_max_seq_len)

        # All instances must agree on the model type. Checking only the first one would
        # crash on None (or silently drop graph matrices) for mixed batches.
        is_sequence_transformer = data_instances[0].relative_position is None
        for data in data_instances:
            if (data.relative_position is None) != is_sequence_transformer:
                raise ValueError("cannot batch graph-transformer and sequence-transformer Data instances together")
            if is_sequence_transformer:
                assert data.sparsity_mask is None
                assert data.use_additional_bucket is None
            else:
                assert data.sparsity_mask is not None
                assert data.use_additional_bucket is not None

        # initialize (padded) tensors
        batch_size = len(data_instances)
        input_ids = torch.full((batch_size, max_seq_len), tokenizer.pad_token_id, dtype=torch.long, device=device)
        if is_sequence_transformer:
            relative_position = None
            sparsity_mask = None
            use_additional_bucket = None
        else:
            relative_position = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.long, device=device)
            sparsity_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
            use_additional_bucket = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)

        # copy every instance into its row, truncated to max_seq_len
        for i, data in enumerate(data_instances):
            instance_len = min(data.input_ids.shape[1], max_seq_len)
            input_ids[i, :instance_len] = data.input_ids[:, :instance_len]
            if not is_sequence_transformer:
                relative_position[i, :instance_len, :instance_len] = data.relative_position[:, :instance_len, :instance_len]
                sparsity_mask[i, :instance_len, :instance_len] = data.sparsity_mask[:, :instance_len, :instance_len]
                use_additional_bucket[i, :instance_len, :instance_len] = data.use_additional_bucket[:, :instance_len, :instance_len]

        model_input = {
            'input_ids': input_ids,
            'relative_position': relative_position,
            'sparsity_mask': sparsity_mask,
            'use_additional_bucket': use_additional_bucket,
            **kwargs
        }
        return model_input

    @staticmethod
    def get_embedding(sequence_embedding:torch.Tensor, indices:Dict[str,List[Tuple[int, int]]], concept:str, embedding_aggregation:str="mean"):
        """
        Returns embedding of a concept.
        :param sequence_embedding: the embedding of the whole sequence. shape: (sequence_length, embedding_size)
        :param indices: dictionary mapping each node to its start- and end-index in the sequence. Keys are nodes, values are lists of tuples (start_index, end_index). The lists have a length of 1 for concepts. indices is part of the Data object.
        :param concept: the concept for which the embedding should be returned.
        :param embedding_aggregation: how the embedding of a concept should be aggregated. Either "mean" or "seq". "mean" returns the mean of all tokens of the concept. "seq" returns the embeddings of all tokens of the concept.
        :return: the aggregated embedding of the concept. shape (1, embedding_size) or (number_of_tokens, embedding_size).
        """
        return get_embedding(sequence_embedding, indices, concept, embedding_aggregation)