sw_test / app.py
Edward Beeching
fsa
bee7740
raw
history blame contribute delete
No virus
1.25 kB
import numpy as np
import time
import streamlit as st
from scienceworld import ScienceWorldEnv
st.title("ScienceWorld interactive demo")
# Not referenced in the visible code; presumably meant as a `hash_funcs` entry
# for st.cache so the env object is not hashed — TODO confirm or remove.
hash_env = lambda _: None
import os
import subprocess

# Diagnostic: report which Java runtime (if any) the host provides; the
# ScienceWorld backend presumably needs one — confirm.
# `java -version` prints its banner to stderr, not stdout, so the previous
# `os.popen(...).read()` captured an empty string. Capture both streams here.
try:
    _java_check = subprocess.run(
        ["java", "-version"], capture_output=True, text=True, check=False
    )
    output = (_java_check.stderr or _java_check.stdout).strip()
except FileNotFoundError:
    # No `java` executable on PATH: say so instead of crashing the app.
    output = "java executable not found on PATH"
# Show the captured text itself (previously the literal string 'output' was written).
st.write(output)
@st.cache(allow_output_mutation=True)
def load_env(task_idx=13, simplification_str='easy', step_limit=100):
    """Create a ScienceWorld environment, load one task and reset it.

    Wrapped in Streamlit's cache so repeated script runs reuse the same
    environment; ``allow_output_mutation=True`` because callers step (mutate)
    the returned env.

    Args:
        task_idx: Index into ``env.getTaskNames()`` or, if not an int, a task
            name passed straight to ``env.load``. ``None`` falls back to 13,
            the task this demo has always loaded.
        simplification_str: Simplification preset passed to both ``env.load``
            and ``env.resetWithVariation``.
        step_limit: Passed as the third ``ScienceWorldEnv`` constructor
            argument (presumably the per-episode step limit — confirm).

    Returns:
        Tuple ``(env, obs, info)`` where ``obs``/``info`` come from resetting
        the loaded task to variation 0.
    """
    print('Loading envs')
    env = ScienceWorldEnv("", None, step_limit, 0)
    if task_idx is None:
        task_idx = 13
    if isinstance(task_idx, int):
        task_name = env.getTaskNames()[task_idx]
    else:
        task_name = task_idx
    # Load the task at variation 0 and immediately reset to that same
    # variation; callers wanting a different (train/dev/test) variation would
    # need to reset the env again themselves.
    env.load(task_name, 0, simplification_str)
    obs, info = env.resetWithVariation(0, simplification_str)
    return env, obs, info
class RandomAgent:
    """Baseline agent that chooses uniformly among the currently valid actions."""

    def act(self, info):
        """Return one action drawn at random from ``info['valid']``.

        Uses NumPy's global random state, so ``np.random.seed`` makes the
        choice reproducible.
        """
        candidates = info['valid']
        return np.random.choice(candidates)
num_episodes = 10  # not used by the visible code; kept in case others import it
env, initial_obs, initial_info = load_env()
act = st.text_input('action to perform')
if act:
    # Streamlit re-executes this whole script on each interaction, and on the
    # first load the text box is empty. Only step the (cached, shared) env when
    # the user actually typed an action, instead of sending it an empty command
    # (which would presumably also count against the step limit).
    st.write(f'Action: {act}')
    obs, reward, done, info = env.step(act)
    st.write(f'Observation: {obs.strip()}')
else:
    # Nothing to act on yet: show the observation returned when the env was
    # loaded. NOTE(review): if the env was stepped in an earlier run, this is
    # the start-of-episode observation, not the current state.
    st.write(f'Initial observation: {str(initial_obs).strip()}')