diff --git a/decanlp/models/common.py b/decanlp/models/common.py index e7ae6633..c4a4dda4 100644 --- a/decanlp/models/common.py +++ b/decanlp/models/common.py @@ -412,7 +412,7 @@ class Embedding(nn.Module): def forward(self, x, lengths=None, device=-1): if self.pretrained_embeddings is not None: - pretrained_embeddings = self.pretrained_embeddings[0](x).to(x.device).detach() + pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).to(x.device).detach() else: pretrained_embeddings = None if self.trained_embeddings is not None: