torch.concatenate -> torch.cat for compatibility
This commit is contained in:
@@ -3608,7 +3608,7 @@
|
||||
" with torch.no_grad():\n",
|
||||
" logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))\n",
|
||||
"\n",
|
||||
" weights = torch.concatenate(QKs) # layers * heads * tokens * frames \n",
|
||||
" weights = torch.cat(QKs) # layers * heads * tokens * frames \n",
|
||||
" weights = weights[:, :, :, : duration // AUDIO_SAMPLES_PER_TOKEN].cpu()\n",
|
||||
" weights = medfilt(weights, (1, 1, 1, medfilt_width))\n",
|
||||
" weights = torch.tensor(weights * qk_scale).softmax(dim=-1)\n",
|
||||
|
||||
Reference in New Issue
Block a user