File size: 3,499 Bytes
6df0c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b35ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
import os
import threading
from threading import Event
from typing import Optional

import discord
import gradio as gr
from discord import Permissions
from discord.ext import commands
from discord.utils import oauth_url

import gradio_client as grc
from gradio_client.utils import QueueError

# Startup gate: the Gradio launch at the bottom of the file waits on this.
# It is set by on_ready, or by run_bot when there is no token.
event = Event()

# Discord bot credential from the environment (None when unset).
DISCORD_TOKEN: Optional[str] = os.environ.get("DISCORD_TOKEN")

async def wait(job):
    """Asynchronously poll *job* until ``job.done()`` is true.

    Checks completion first and, while the job is still pending, yields to
    the event loop for 0.2 s between checks. Returns None.
    """
    poll_interval = 0.2
    while True:
        if job.done():
            return
        await asyncio.sleep(poll_interval)

def get_client(session: Optional[str] = None) -> grc.Client:
    """Create a gradio_client Client for the transformers-musicgen Space.

    Authenticates with the ``HF_TOKEN`` environment variable. When *session*
    is truthy, it replaces the new client's ``session_hash``.
    """
    token = os.getenv("HF_TOKEN")
    musicgen_client = grc.Client("huggingface-projects/transformers-musicgen", hf_token=token)
    if not session:
        return musicgen_client
    musicgen_client.session_hash = session
    return musicgen_client

# Gateway intents: the library defaults plus message content.
intents = discord.Intents.default()
intents.message_content = True

# The bot every command/event below is registered on; "/" is the text-command prefix.
bot = commands.Bot(intents=intents, command_prefix="/")

@bot.event
async def on_ready():
    """Log the login, sync application commands, and open the startup gate.

    Prints the bot's identity, syncs the command tree and lists the names of
    the synced commands, then sets ``event`` so code waiting on it proceeds.
    """
    me = bot.user
    print(f"Logged in as {me} (ID: {me.id})")

    synced_commands = await bot.tree.sync()
    command_names = ", ".join(cmd.name for cmd in synced_commands)
    print(f"Synced commands: {command_names}.")

    event.set()
    print("------")


#-----------------------------------------------------------------------
# Registered on `bot`: no module-level `client` exists, so decorating with
# `client.hybrid_command` failed with NameError at import time.
@bot.hybrid_command(
    name="testm",
    description="Enter a prompt to generate music!",
)
async def musicgen_command(ctx, prompt: str, seed: Optional[int] = None):
    """Generates music based on a prompt"""
    # NOTE: the docstring above is left as-is on purpose; discord.py reads it
    # as the command's help text.
    #
    # Ignores invocations by the bot itself. Picks a random seed in
    # [1, 10000] when none is given, then delegates to music_create. Any
    # error is printed rather than propagated.
    import random  # `random` is not imported at module level

    if ctx.author.id == bot.user.id:
        return
    if seed is None:
        seed = random.randint(1, 10000)
    try:
        await music_create(ctx, prompt, seed)
    except Exception as e:
        print(f"Error: {e}")
#-----------------------------------------------------------------------
async def music_create(ctx, prompt, seed):
    """Runs music_create_job in executor and posts the generated media.

    Announces the request in the channel, runs the blocking
    ``music_create_job`` in the loop's default executor, then uploads the
    video (``.mp4``) and audio (``.mp3``) outputs. The upload filenames are
    derived from the first 20 characters of *prompt*. Errors are printed,
    not raised.

    Args:
        ctx: command context used to reply.
        prompt: text prompt forwarded to the Space.
        seed: seed forwarded to the Space.
    """
    try:
        await ctx.send(f"**{prompt}** - {ctx.author.mention} Generating...")

        loop = asyncio.get_running_loop()
        job = await loop.run_in_executor(None, music_create_job, prompt, seed)
        if job is None:
            # Submission failed; music_create_job has already printed why.
            return

        try:
            job.result()
            media_files = job.outputs()[0]
        except QueueError:
            await ctx.send("The gradio space powering this bot is really busy! Please try again later!")
            # Nothing was generated, so there is nothing to upload. Without
            # this return the code below hit an unbound `media_files`.
            return

        # Presumably (audio_path, video_path) from the Space — TODO confirm.
        audio = media_files[0]
        video = media_files[1]
        short_filename = prompt[:20]
        audio_filename = f"{short_filename}.mp3"
        video_filename = f"{short_filename}.mp4"

        # Hand discord.File the paths. It then opens the files itself and
        # closes them once the upload completes, instead of relying on
        # handles opened here.
        await ctx.send(file=discord.File(video, filename=video_filename))
        await ctx.send(file=discord.File(audio, filename=audio_filename))

    except Exception as e:
        print(f"music_create Error: {e}")


def music_create_job(prompt, seed):
    """Generates music based on a given prompt.

    Submits ``(prompt, seed)`` to the Space's ``/predict`` endpoint and
    blocks until the job reports done. It is called from an executor thread
    by music_create.

    Returns:
        The finished job, whose ``result()`` may still raise (e.g.
        QueueError). Returns None if creating the client or submitting
        failed; in that case the error is printed.
    """
    import time  # used only for the polling back-off below

    try:
        # No shared client is defined in this file, so build one with
        # get_client() (the previous bare `musicgen` name was undefined).
        client = get_client()
        job = client.submit(prompt, seed, api_name="/predict")
        # Sleep between polls: a bare `pass` loop pinned a CPU core for the
        # whole generation.
        while not job.done():
            time.sleep(0.2)
        return job

    except Exception as e:
        print(f"music_create_job Error: {e}")
        return None


#---------------------------------------------------------------------
def run_bot():
    """Run the Discord bot (blocking) and always release the startup gate.

    When DISCORD_TOKEN is unset it prints a notice and returns. Otherwise it
    calls ``bot.run``, which blocks while the bot is running. ``event`` is
    set in ``finally`` in every case. Without that, an early exit or
    exception from ``bot.run`` (e.g. a login failure) before ``on_ready``
    fires would leave the main thread's ``event.wait()`` hanging forever.
    The exception itself still propagates.
    """
    try:
        if not DISCORD_TOKEN:
            print("DISCORD_TOKEN NOT SET")
            return
        bot.run(DISCORD_TOKEN)
    finally:
        event.set()


# bot.run() blocks for as long as the bot is running, so run_bot gets its own
# thread while the main thread goes on to serve the Gradio page.
threading.Thread(target=run_bot).start()

# Do not build or launch the page until `event` is set (e.g. by on_ready once
# the bot has logged in, or by run_bot when no token is configured).
event.wait()

_PAGE_MARKDOWN = """
    # Discord bot of https://huggingface.co/spaces/facebook/MusicGen
    """

with gr.Blocks() as demo:
    gr.Markdown(_PAGE_MARKDOWN)

demo.launch()