divakaivan commited on
Commit
908706b
1 Parent(s): 9816c8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -123,6 +123,7 @@ dataset = dataset.map(
123
  prepare_dataset, remove_columns=dataset.column_names,
124
  )
125
 
 
126
 
127
  def predict(text, speaker):
128
  if len(text.strip()) == 0:
@@ -135,7 +136,7 @@ def predict(text, speaker):
135
  input_ids = input_ids[..., :model.config.max_text_positions]
136
 
137
  ### ### ###
138
- example = dataset.iloc[0]
139
  speaker_embeddings = torch.tensor(example["speaker_embeddings"]).unsqueeze(0)
140
 
141
  speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
 
123
  prepare_dataset, remove_columns=dataset.column_names,
124
  )
125
 
126
+ dataset = dataset.train_test_split(test_size=0.1)
127
 
128
  def predict(text, speaker):
129
  if len(text.strip()) == 0:
 
136
  input_ids = input_ids[..., :model.config.max_text_positions]
137
 
138
  ### ### ###
139
+ example = dataset['test'][11]
140
  speaker_embeddings = torch.tensor(example["speaker_embeddings"]).unsqueeze(0)
141
 
142
  speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)