metadata
language:
  - en
base_model: google/flan-t5-small

GLM-flan-t5-small

This model is designed to process text-attributed graphs, texts, and interleaved inputs of both. It applies the architectural changes from Graph Language Models (GLMs) to the encoder of google/flan-t5-small. The parameters are taken unchanged from google/flan-t5-small, so the model has not been trained with the new architecture and should be trained (e.g., fine-tuned on the target task) to obtain the best performance.
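As a rough intuition for how a graph can be read so that every concept appears only once, here is a minimal, self-contained sketch. It assumes a simplification of the idea described in the paper: each triplet's relation becomes its own node (a Levi graph), and hop distances in that graph play the role that token distances play in a plain text sequence. The helpers `levi_graph` and `hop_distances` are hypothetical and are not part of the model's `data_processor`, which operates on tokens rather than whole concepts.

```python
from collections import deque

Triplet = tuple[str, str, str]


def levi_graph(triplets: list[Triplet]) -> dict[str, set[str]]:
    """Build an undirected Levi graph: concepts are shared nodes, while every
    relation occurrence becomes its own node (so two 'is a' edges stay distinct)."""
    adj: dict[str, set[str]] = {}
    for i, (head, rel, tail) in enumerate(triplets):
        rel_node = f"{rel}#{i}"  # the triplet index makes each relation occurrence unique
        for a, b in ((head, rel_node), (rel_node, tail)):
            adj.setdefault(a, set()).add(b)
            adj.setdefault(b, set()).add(a)
    return adj


def hop_distances(adj: dict[str, set[str]], source: str) -> dict[str, int]:
    """Breadth-first search: number of hops from `source` to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return dist


graph_1 = [
    ('black poodle', 'is a', 'dog'),
    ('dog', 'is a', 'animal'),
    ('cat', 'is a', 'animal'),
]
adj = levi_graph(graph_1)
dist = hop_distances(adj, 'black poodle')
print(sorted(adj))  # 'dog' and 'animal' occur once; each 'is a' is a separate node
print(dist['dog'], dist['animal'], dist['cat'])
```

In this toy view, `dog` is 2 hops from `black poodle` and `cat` is 6 hops away, even though all three concepts are mentioned only once. How the actual model turns such structure into attention biases is described in the paper.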

Paper abstract:

While Language Models (LMs) are the workhorses of NLP, their interplay with structured knowledge graphs (KGs) is still actively researched. Current methods for encoding such graphs typically either (i) linearize them for embedding with LMs – which underutilize structural information, or (ii) use Graph Neural Networks (GNNs) to preserve the graph structure – but GNNs cannot represent text features as well as pretrained LMs. In our work we introduce a novel LM type, the Graph Language Model (GLM), that integrates the strengths of both approaches and mitigates their weaknesses. The GLM parameters are initialized from a pretrained LM to enhance understanding of individual graph concepts and triplets. Simultaneously, we design the GLM’s architecture to incorporate graph biases, thereby promoting effective knowledge distribution within the graph. This enables GLMs to process graphs, texts, and interleaved inputs of both. Empirical evaluations on relation classification tasks show that GLM embeddings surpass both LM- and GNN-based baselines in supervised and zero-shot setting, demonstrating their versatility.

Usage

In the paper, we evaluate the model as a graph (and text) encoder for (text-guided) relation classification on ConceptNet and WikiData subgraphs. However, the model can be used for any task that requires encoding text-attributed graphs, texts, or interleaved inputs of both. See Encoding Graphs and Texts for an example implementation.

As we build on the T5 architecture, the model can be combined with the T5 decoder for generation. See Generating from Graphs and Texts for an example implementation.

Note that the model has not been trained with the new architecture, so it should be trained (e.g., fine-tuned on the target task) to obtain the best performance.

Encoding Graphs and Texts

from transformers import AutoTokenizer, AutoModel

modelcard = 'plenz/GLM-flan-t5-small'

print('Load the model and tokenizer')
model = AutoModel.from_pretrained(modelcard, trust_remote_code=True, revision='main')
tokenizer = AutoTokenizer.from_pretrained(modelcard)

print('get dummy input (2 instances to show batching)')
graph_1 = [
    ('black poodle', 'is a', 'dog'),
    ('dog', 'is a', 'animal'),
    ('cat', 'is a', 'animal')
]
text_1 = 'The dog chased the cat.'

graph_2 = [
    ('dog', 'is a', 'animal'),
    ('dog', 'has', 'tail'),
    ('dog', 'has', 'fur'),
    ('fish', 'is a', 'animal'),
    ('fish', 'has', 'scales')
]
text_2 = None  # only graph for this instance

print('prepare model inputs')
how = 'global'  # can be 'global' or 'local', depending on whether the local or global GLM should be used. See paper for more details. 
data_1 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_1, text=text_1, how=how)
data_2 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_2, text=text_2, how=how)
datas = [data_1, data_2]
model_inputs = model.data_processor.to_batch(data_instances=datas, tokenizer=tokenizer, max_seq_len=None, device='cpu')

print('compute token encodings')
outputs = model(**model_inputs)

# get token embeddings of all graph and text tokens. Nodes in the graph (e.g., dog) appear only once in the sequence.
print('Sequence of tokens (batch_size, max_seq_len, embedding_dim):', outputs.last_hidden_state.shape)

# get the embedding of a single concept in the first instance
poodle_embedding = model.data_processor.get_embedding(
    sequence_embedding=outputs.last_hidden_state[0],
    indices=data_1.indices,
    concept='black poodle',
    embedding_aggregation='seq',  # 'seq' returns the sequence of embeddings (e.g., all tokens of `black poodle`), 'mean' returns the mean of the embeddings
)
print('embedding of `black poodle` in the first instance. Shape is (seq_len, embedding_dim):', poodle_embedding.shape)
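The difference between the two `embedding_aggregation` options of `get_embedding` can be pictured with the plain-Python sketch below. It assumes the positions of the concept's tokens in the sequence are already known; in the example above they are derived from `data_1.indices`, whose exact structure is not documented here, so the `positions` argument and the `aggregate` helper are hypothetical. With 'seq', one vector per token is kept; with 'mean', they are averaged into a single vector of size `embedding_dim`.

```python
def aggregate(sequence_embedding: list[list[float]], positions: list[int], how: str = 'seq'):
    """Select the token vectors of one concept and optionally mean-pool them.

    sequence_embedding: one vector per token, i.e. shape (seq_len, embedding_dim)
    positions: indices of the concept's tokens in that sequence (hypothetical input)
    """
    token_vectors = [sequence_embedding[i] for i in positions]
    if how == 'seq':
        return token_vectors  # shape (num_concept_tokens, embedding_dim)
    if how == 'mean':
        dim = len(token_vectors[0])
        # average each embedding dimension over the concept's tokens
        return [sum(vec[d] for vec in token_vectors) / len(token_vectors) for d in range(dim)]
    raise ValueError(f"unknown aggregation: {how!r}")


# toy sequence of 4 tokens with 2-dimensional embeddings
seq = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(aggregate(seq, [0, 1], 'seq'))   # [[1.0, 2.0], [3.0, 4.0]]
print(aggregate(seq, [0, 1], 'mean'))  # [2.0, 3.0]
```

'seq' preserves token-level detail for multi-token concepts such as `black poodle`, while 'mean' yields a fixed-size vector that is convenient for downstream classifiers.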

Generating from Graphs and Texts

from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration

modelcard = 'plenz/GLM-flan-t5-small'  
modelcard_generation = 'google/flan-t5-small'  

print('load the model and tokenizer')
model_generation = T5ForConditionalGeneration.from_pretrained(modelcard_generation)
del model_generation.encoder  # we only need the decoder for generation. Deleting the encoder is optional, but saves memory.
model = AutoModel.from_pretrained(modelcard, trust_remote_code=True, revision='main')
tokenizer = AutoTokenizer.from_pretrained(modelcard)


print('get dummy input (2 instances to show batching)')
graph_1 = [
    ('black poodle', 'is a', 'dog'),
    ('dog', 'is a', 'animal'),
    ('cat', 'is a', 'animal')
]
text_1 = 'summarize: The black poodle chased the cat.'  # with T5 prefix

graph_2 = [
    ('dog', 'is a', 'animal'),
    ('dog', 'has', 'tail'),
    ('dog', 'has', 'fur'),
    ('fish', 'is a', 'animal'),
    ('fish', 'has', 'scales')
]
text_2 = "Dogs have <extra_id_0> and fish have <extra_id_1>. Both are <extra_id_2>."  # T5 MLM

print('prepare model inputs')
how = 'global'  # can be 'global' or 'local', depending on whether the local or global GLM should be used. See paper for more details. 
data_1 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_1, text=text_1, how=how)
data_2 = model.data_processor.encode_graph(tokenizer=tokenizer, g=graph_2, text=text_2, how=how)
datas = [data_1, data_2]
model_inputs = model.data_processor.to_batch(data_instances=datas, tokenizer=tokenizer, max_seq_len=None, device='cpu')

print('compute token encodings')
outputs = model(**model_inputs)

print('generate conditional on encoded graph and text')
generated_ids = model_generation.generate(encoder_outputs=outputs, max_new_tokens=10)
print('generation 1:', tokenizer.decode(generated_ids[0], skip_special_tokens=True))
print('generation 2:', tokenizer.decode(generated_ids[1], skip_special_tokens=False))  # keep special tokens to see which <extra_id_n> sentinel each predicted span belongs to
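The second instance uses T5's span-corruption (masked language modeling) format, which is why it is decoded with `skip_special_tokens=False`: the decoder answers with a sentinel token followed by the text for each masked span. The sketch below shows one way to map such a decoded string back onto the `<extra_id_n>` placeholders of the input. The helpers `parse_sentinel_output` and `fill_masks` are hypothetical (not part of this model or of `transformers`), and `example_output` is an invented string in the usual T5 output format, not an actual output of this untrained model.

```python
import re

SENTINEL = re.compile(r"<extra_id_(\d+)>")


def parse_sentinel_output(decoded: str) -> dict[str, str]:
    """Map each sentinel token in a T5 span-corruption output to the span it predicts."""
    # remove the padding / end-of-sequence tokens that T5 emits around the answer
    cleaned = decoded.replace("<pad>", "").replace("</s>", "")
    # re.split with a capturing group returns [prefix, id0, span0, id1, span1, ...]
    parts = SENTINEL.split(cleaned)
    spans = {}
    for sentinel_id, span in zip(parts[1::2], parts[2::2]):
        spans[f"<extra_id_{sentinel_id}>"] = span.strip()
    return spans


def fill_masks(masked_text: str, spans: dict[str, str]) -> str:
    """Replace every sentinel in the input text with its predicted span (if any)."""
    return SENTINEL.sub(lambda m: spans.get(m.group(0), m.group(0)), masked_text)


text_2 = "Dogs have <extra_id_0> and fish have <extra_id_1>. Both are <extra_id_2>."
example_output = "<pad> <extra_id_0> fur<extra_id_1> scales<extra_id_2> animals</s>"  # invented example
spans = parse_sentinel_output(example_output)
print(spans)
print(fill_masks(text_2, spans))  # Dogs have fur and fish have scales. Both are animals.
```

Sentinels that the decoder does not produce are left unchanged in the filled text, which makes incomplete generations (e.g., when `max_new_tokens` is reached) easy to spot.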

Contact

More information can be found in our paper Graph Language Models or our GitHub repository.

If you have any questions or comments, please feel free to send us an email at plenz@cl.uni-heidelberg.de.

If this model is helpful for your work, please consider citing the paper:

@inproceedings{plenz-frank-2024-graph,
    title = "Graph Language Models",
    author = "Plenz, Moritz and Frank, Anette",
    booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics",
    year = "2024",
    address = "Bangkok, Thailand",
    publisher = "Association for Computational Linguistics",
}