ymzhang319's picture
init
7f2690b
raw
history blame
No virus
593 Bytes
import torch
class RawFeatsStage(object):
def __init__(self):
pass
def eval(self):
return self
def encode(self, c):
"""fake vqmodel interface because self.cond_stage_model should have something
similar to coord.py but even more `dummy`"""
# assert 0.0 <= c.min() and c.max() <= 1.0
info = None, None, c
return c, None, info
def decode(self, c):
return c
def get_input(self, batch, k):
x = batch[k]
x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
return x.float()