Skip to content
Snippets Groups Projects

Implements Ensamble model

Merged Roko Parac requested to merge ensamble into master
1 file
+ 18
0
Compare changes
  • Side-by-side
  • Inline
+ 18
0
@@ -43,3 +43,21 @@ def make_predictions(model, tokenizer, dataset, batch_size, device):
y_true = None
return y_pred, y_true
class EnsambleModel:
def __init__(self, models, tokenizers):
self.models = models
self.tokenizers = tokenizers
if len(models) % 2 == 0:
raise ValueError('Need an odd number of models to avoid ties')
def make_predictions(self, dataset, batch_size, device):
y_preds, y_true = torch.Tensor([0]).int(), None
for model, tokenizer in zip(self.models, self.tokenizers):
y_pred, y_true = make_predictions(model, tokenizer, dataset, batch_size, device)
y_preds += torch.Tensor(y_pred).int()
y_preds = (y_preds > len(self.models) // 2).int().tolist()
return y_preds, y_true
Loading