fnavales commited on
Commit
cd08337
1 Parent(s): 4d031cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -19,7 +19,17 @@ class ToxicCommentTagger(pl.LightningModule):
19
  self.n_training_steps = n_training_steps
20
  self.n_warmup_steps = n_warmup_steps
21
  self.criterion = nn.BCELoss()
22
-
 
 
 
 
 
 
 
 
 
 
23
 
24
  def predict(model, tokenizer, sentence):
25
 
 
19
  self.n_training_steps = n_training_steps
20
  self.n_warmup_steps = n_warmup_steps
21
  self.criterion = nn.BCELoss()
22
+
23
+
24
+ def forward(self, input_ids, attention_mask, labels=None):
25
+ output = self.bert(input_ids, attention_mask=attention_mask)
26
+ output = self.classifier(output.pooler_output)
27
+ output = torch.sigmoid(output)
28
+ loss = 0
29
+ if labels is not None:
30
+ loss = self.criterion(output, labels)
31
+ return loss, output
32
+
33
 
34
  def predict(model, tokenizer, sentence):
35