whistlegen_v2 / app.py
johnowhitaker's picture
Update app.py
ac57ec9
raw
history blame contribute delete
No virus
3.47 kB
import gradio as gr
import json, urllib
from model import GPT, GPTConfig
from utils import sample
import torch
import pickle
device = torch.device('cpu')

# Create the model. The hyperparameters must match the architecture the
# checkpoint below was trained with; otherwise load_state_dict raises on
# mismatched keys or shapes.
vocab_size = 147
block_size = 128
mconf = GPTConfig(vocab_size, block_size,
                  n_layer=6, n_head=8, n_embd=256)
model = GPT(mconf)

# Load checkpoint (tensors mapped onto the CPU device above).
model.load_state_dict(torch.load('another_epoch_1.75total.ckpt', map_location=device))
# nn.Module instances start in training mode; switch to inference mode so any
# dropout / train-only behaviour inside GPT is disabled while sampling.
model.eval()

# Vocab: `stoi` maps prompt byte values to token ids and `itos` maps token ids
# back to character codes (see how both are used in generate_song).
# NOTE(review): pickle.load can execute arbitrary code; only ship/load trusted
# copies of these files.
with open('stoi.pkl', 'rb') as vocab_file:
    stoi = pickle.load(vocab_file)
with open('itos.pkl', 'rb') as vocab_file:
    itos = pickle.load(vocab_file)
# Post-process generation
def completion_to_song(c):
    """Trim a raw model completion down to the first generated song.

    Lines are kept as-is until the first line containing a bar line ('|') is
    seen. From then on, scanning stops (without keeping the stopping line) at
    a new 'T:' title field, which starts another song, or at an empty /
    near-empty line.

    Args:
        c: Raw completion text, including the prompt it was generated from.

    Returns:
        The kept lines joined with newlines.
    """
    kept_lines = []
    seen_notes = False
    for line in c.split('\n'):
        # Record if we've hit music (a bar line).
        if '|' in line:
            seen_notes = True
        if seen_notes:
            # Stop if we go back to the start of another song. Match the 'T:'
            # title field rather than any 'T': in ABC, 'T' is also the trill
            # decoration and can legitimately appear inside a line of music.
            if line.lstrip().startswith('T:'):
                break
            # Stop on an empty (or near-empty) line.
            if len(line.strip()) < 2:
                break
        # Otherwise keep the line.
        kept_lines.append(line)
    return '\n'.join(kept_lines)
# Generate function
def generate_song(randomize, title, nu, ks, key):
    """Sample a tune in ABC notation from the model and render it as HTML.

    Args:
        randomize: If truthy, the prompt is just 'T:' and the remaining
            arguments are ignored.
        title: Tune title written after 'T:'.
        nu: Value for the 'L:' (note unit) field, e.g. '1/8'.
        ks: Value for the 'M:' (time signature) field, e.g. '4/4'. The UI
            option 'Random' is replaced by one of '3/4', '4/4', '6/8'.
        key: Value for the 'K:' (key) field, e.g. 'D'.

    Returns:
        An HTML string: a link that opens the tune in the external abcjs
        editor (tune URL-encoded in the `t` query parameter), followed by the
        HTML-escaped tune text with newlines rendered as <br>.
    """
    # Function-scope imports keep this fix self-contained. The file-level
    # `import urllib` does not by itself load the `urllib.parse` submodule;
    # it only worked if another module had already imported it.
    import html
    import random
    import urllib.parse

    # Start sequence: the ABC header fields used as the prompt.
    context = b"T:"
    if not randomize:
        if ks == 'Random':
            # Pick a concrete meter instead of feeding the literal text
            # 'M:Random' to the model.
            ks = random.choice(['3/4', '4/4', '6/8'])
        context += bytes(title + '\n', 'utf-8')
        context += bytes('M:' + ks + '\n', 'utf-8')
        context += bytes('K:' + key + '\n', 'utf-8')
        context += bytes('L:' + nu + '\n', 'utf-8')

    # Model inputs: map prompt bytes to token ids, dropping bytes the
    # vocabulary lacks (e.g. from a non-ASCII title) instead of raising
    # KeyError. NOTE(review): assumes stoi is a dict keyed by byte value.
    token_ids = [stoi[b] for b in context if b in stoi]
    x = torch.tensor(token_ids, dtype=torch.long)[None, ...].to(device)

    # Completion. 400 is presumably the number of new tokens to sample (see
    # utils.sample); gradients are never needed here.
    with torch.no_grad():
        y = sample(model, x, 400, temperature=1.0, sample=True, top_k=10)[0]
    completion = ''.join([chr(itos[int(i)]) for i in y])

    # Return the first song.
    song = completion_to_song(completion)

    # Escape before embedding in HTML: ABC uses '<' / '>' for broken rhythms
    # and the title is user input, so raw text could be parsed as markup.
    html_song = html.escape(song).replace('\n', '<br>')
    # NOTE(review): '&', '#' and '+' are left unescaped in the query value, so
    # a tune containing them may be cut or altered when the editor parses
    # `t` -- confirm how editor.drawthedots.com reads it before changing.
    url_song = urllib.parse.quote(song, safe='~@#$&()*!+=:;,?/\'')
    html_text = '<p><a href="https://editor.drawthedots.com?t=' + url_song + '" target="_blank"><b>EDIT LINK - click to open abcjs editor (allows download and playback)</b></a></p>' + "<p>" + html_song + '</p>'
    return html_text
# Gradio demo: prompt settings, a Run button and an HTML output pane wired to
# generate_song.
demo = gr.Blocks()
with demo:
    gr.Markdown("Quick demo for [WhistleGen v2](https://wandb.ai/johnowhitaker/whistlegen_v2/reports/WhistleGen-v2--VmlldzoyMTAwNjAz) which lets you generate folk music using a transformer model. I can't get the javascript needed for rendering and playback working with gradio, so this shows the raw ABC notation from the model and a link to view it properly in an external editor.")

    # Title and note unit.
    with gr.Row():
        title = gr.Text(value='The March of AI', label='Title')
        with gr.Column():
            nu = gr.Text(value='1/8', label='Note unit')

    # Time signature (the 'M:' field) and key.
    with gr.Row():
        time_signature = gr.Dropdown(['3/4', '4/4', '6/8', 'Random'], label='Time Signature', value='4/4')
        with gr.Column():
            key = gr.Text(value='D', label='Key')

    with gr.Row():
        randomize = gr.Checkbox(value=True, label='Randomize (ignores settings above)')

    with gr.Row():
        out = gr.HTML(value='Output should appear here (takes ~30s)', label="Output")

    # Order must match generate_song(randomize, title, nu, ks, key).
    song_inputs = [randomize, title, nu, time_signature, key]
    run_button = gr.Button("Run")
    run_button.click(fn=generate_song, inputs=song_inputs, outputs=out)

    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=whistlegen_v2_space)")
    gr.Markdown("This is currently using an early model. See the [report](https://wandb.ai/johnowhitaker/whistlegen_v2/reports/WhistleGen-v2--VmlldzoyMTAwNjAz) for training info and updates.")

# Launch the app with request queueing enabled.
demo.launch(enable_queue=True)