File size: 2,406 Bytes
2e605bf
 
 
 
 
7198503
 
2e605bf
7198503
2e605bf
 
 
 
 
7198503
2e605bf
7198503
 
2e605bf
 
7198503
 
2e605bf
7198503
 
2e605bf
 
 
7198503
 
2e605bf
 
b01e61c
7198503
 
 
 
b01e61c
2e605bf
 
 
 
 
 
7198503
 
2e605bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7198503
2e605bf
 
7198503
2e605bf
7198503
 
 
 
2e605bf
 
 
 
 
 
7198503
2e605bf
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import functools
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceSeq2SeqGenerator,
    HuggingFaceGenerationAlgorithm
)
from transformers import AutoTokenizer

# Module-level logger. NullHandler is the stdlib convention for library-style
# code: it silences "No handler found" warnings when the hosting application
# has not configured logging, without forcing any output.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

@functools.lru_cache(maxsize=1)
def _load_tokenizer(name: str = "t5-small"):
    """Load and cache the tokenizer used to strip special tokens from outputs.

    Caching avoids re-fetching/re-instantiating the tokenizer on every
    inference call (the original code reloaded it each time).
    """
    return AutoTokenizer.from_pretrained(name)


def run_inference(
    model_name_or_path: str,
    prefix: str,
    prompt: str,
    num_beams: int,
) -> str:
    """Run a single generation with a GT4SD HuggingFace seq2seq model.

    Args:
        model_name_or_path: algorithm version (model identifier) to load.
        prefix: task-specific prefix prepended to the prompt.
        prompt: input text for the model.
        num_beams: beam-search width for generation.

    Returns:
        The generated text with the echoed ``prefix + prompt``, the part
        after the first EOS token, and any padding tokens removed.
    """
    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=prefix,
        prompt=prompt,
        num_beams=num_beams,
    )
    model = HuggingFaceGenerationAlgorithm(config)

    # NOTE(review): assumes every selectable model is a T5 variant, so the
    # t5-small special tokens (</s>, <pad>) apply to all versions — confirm
    # if non-T5 algorithm versions are ever added.
    tokenizer = _load_tokenizer()

    # Draw exactly one sample from the generation iterator.
    text = next(iter(model.sample(1)))

    # Clean up: drop the echoed input, cut at the first EOS, remove padding.
    text = text.replace(prefix + prompt, "")
    text = text.split(tokenizer.eos_token)[0]
    text = text.replace(tokenizer.pad_token, "")
    return text.strip()


if __name__ == "__main__":

    # Preparation (retrieve all available algorithms).
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Metadata (model card text and example inputs) lives next to this file.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # header=None: the CSV has no header row; fillna("") keeps optional
    # columns (e.g. an empty prefix) as empty strings rather than NaN,
    # which Gradio example rows require.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    print("Examples: ", examples.values.tolist())

    # read_text with explicit encoding avoids the platform-dependent
    # default encoding of bare open().
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Text-chem-T5 model",
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Textbox(
                label="Prefix", placeholder="A task-specific prefix", lines=1
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)