File size: 1,523 Bytes
7c1eee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from .imagenhub_models import load_imagenhub_model
from .playground_api import load_playground_model

# Registered model identifiers, each of the form "<source>_<name>_<type>".

# Text-to-image generation models.
IMAGE_GENERATION_MODELS = [
    # served through imagenhub
    'imagenhub_LCM_generation',
    'imagenhub_SDXLTurbo_generation',
    'imagenhub_SDXL_generation',
    'imagenhub_PixArtAlpha_generation',
    'imagenhub_OpenJourney_generation',
    'imagenhub_SDXLLightning_generation',
    'imagenhub_StableCascade_generation',
    # served through the playground API
    'playground_PlayGroundV2_generation',
    'playground_PlayGroundV2.5_generation',
]

# Image editing models.
IMAGE_EDITION_MODELS = [
    'imagenhub_CycleDiffusion_edition',
    'imagenhub_Pix2PixZero_edition',
    'imagenhub_Prompt2prompt_edition',
    'imagenhub_SDEdit_edition',
    'imagenhub_InstructPix2Pix_edition',
    'imagenhub_MagicBrush_edition',
    'imagenhub_PNP_edition',
]


def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name.

    Args:
        model_name (str): The name of the model to load, of the form
            ``{source}_{name}_{type}``:
            - ``source`` is either ``imagenhub`` or ``playground``;
            - ``name`` is the name passed to the source's loader (it may
              itself contain underscores);
            - ``type`` is the model type, e.g. ``generation`` or ``edition``.

    Returns:
        The pipeline object returned by ``load_imagenhub_model`` (for the
        ``imagenhub`` source) or ``load_playground_model`` (for the
        ``playground`` source).

    Raises:
        ValueError: If ``model_name`` has fewer than three ``_``-separated
            segments, or if its source is not supported.
    """
    # Peel the source off the front and the type off the back so that the
    # middle (model name) segment may contain underscores; a plain
    # ``split("_")`` unpacked into three names fails with an opaque
    # unpacking error in that case.
    model_source, sep_source, remainder = model_name.partition("_")
    base_name, sep_type, model_type = remainder.rpartition("_")
    if not sep_source or not sep_type:
        raise ValueError(
            f"Model name {model_name!r} must have the form '<source>_<name>_<type>'"
        )
    if model_source == "imagenhub":
        pipe = load_imagenhub_model(base_name, model_type)
    elif model_source == "playground":
        # The playground loader only takes the model name; the type is not forwarded.
        pipe = load_playground_model(base_name)
    else:
        raise ValueError(f"Model source {model_source} not supported")
    return pipe