Edit model card

GPT2 Zinc 87m

This is a GPT2 style autoregressive language model trained on ~480m SMILES strings from the ZINC database.

The model has ~87m parameters and was trained for 175000 iterations with a batch size of 3072 to a validation loss of ~.615. This model is useful for generating druglike molecules or generating embeddings from SMILES strings

How to use

from transformers import GPT2TokenizerFast, GPT2LMHeadModel

tokenizer = GPT2TokenizerFast.from_pretrained("entropy/gpt2_zinc_87m", max_len=256)
model = GPT2LMHeadModel.from_pretrained('entropy/gpt2_zinc_87m')

To generate molecules:

inputs = torch.tensor([[tokenizer.bos_token_id]])

gen = model.generate(
              inputs,
              do_sample=True, 
              max_length=256, 
              temperature=1.,
              early_stopping=True,
              pad_token_id=tokenizer.pad_token_id,
              num_return_sequences=32
                         )
smiles = tokenizer.batch_decode(gen, skip_special_tokens=True)

To compute embeddings:

from transformers import DataCollatorWithPadding

collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

inputs = collator(tokenizer(smiles))
outputs = model(**inputs, output_hidden_states=True)
full_embeddings = outputs[-1][-1]
mask = inputs['attention_mask']
embeddings = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))

WARNING

This model was trained with bos and eos tokens around SMILES inputs. The GPT2TokenizerFast tokenizer DOES NOT ADD special tokens, even when add_special_tokens=True. Huggingface says this is intended behavior.

It may be necessary to manually add these tokens

inputs = collator(tokenizer([tokenizer.bos_token+i+tokenizer.eos_token for i in smiles]))

Model Performance

To test generation performance, 1m compounds were generated at various temperature values. Generated compounds were checked for uniqueness and structural validity.

  • percent_unique denotes n_unique_smiles/n_total_smiles
  • percent_valid denotes n_valid_smiles/n_unique_smiles
  • percent_unique_and_valid denotes n_valid_smiles/n_total_smiles
temperature percent_unique percent_valid percent_unique_and_valid
0.5 0.928074 1 0.928074
0.75 0.998468 0.999967 0.998436
1 0.999659 0.999164 0.998823
1.25 0.999514 0.99351 0.993027
1.5 0.998749 0.970223 0.96901

Property histograms computed over 1m generated compounds: property histograms

Downloads last month
423
Inference API
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.