Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
3dd0e68
1
Parent(s):
94706c2
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,7 +23,7 @@ def generate(
|
|
| 23 |
seq_len,
|
| 24 |
max_seq_len = 2048,
|
| 25 |
temperature = 0.9,
|
| 26 |
-
verbose=
|
| 27 |
return_prime=False,
|
| 28 |
):
|
| 29 |
|
|
@@ -34,25 +34,30 @@ def generate(
|
|
| 34 |
if verbose:
|
| 35 |
print("Generating sequence of max length:", seq_len)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
torch_in = x.tolist()[0]
|
| 41 |
-
|
| 42 |
-
logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
if return_prime:
|
| 57 |
return out[:, :]
|
| 58 |
|
|
@@ -134,7 +139,7 @@ def GenerateMIDI():
|
|
| 134 |
|
| 135 |
midi_data = TMIDIX.score2midi(output, text_encoding)
|
| 136 |
|
| 137 |
-
with open(f"Allegro-Music-Transformer-Music-Composition", 'wb') as f:
|
| 138 |
f.write(midi_data)
|
| 139 |
|
| 140 |
audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
|
|
@@ -191,8 +196,10 @@ if __name__ == "__main__":
|
|
| 191 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 192 |
opt = parser.parse_args()
|
| 193 |
|
|
|
|
| 194 |
session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider'])
|
| 195 |
-
|
|
|
|
| 196 |
load_javascript()
|
| 197 |
app = gr.Blocks()
|
| 198 |
with app:
|
|
|
|
| 23 |
seq_len,
|
| 24 |
max_seq_len = 2048,
|
| 25 |
temperature = 0.9,
|
| 26 |
+
verbose=False,
|
| 27 |
return_prime=False,
|
| 28 |
):
|
| 29 |
|
|
|
|
| 34 |
if verbose:
|
| 35 |
print("Generating sequence of max length:", seq_len)
|
| 36 |
|
| 37 |
+
max_len = seq_len
|
| 38 |
+
cur_len = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
|
| 41 |
+
with bar:
|
| 42 |
+
while cur_len < max_len:
|
| 43 |
+
|
| 44 |
+
x = out[:, -max_seq_len:]
|
| 45 |
+
|
| 46 |
+
torch_in = x.tolist()[0]
|
| 47 |
+
|
| 48 |
+
logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
|
| 49 |
+
|
| 50 |
+
filtered_logits = logits
|
| 51 |
+
|
| 52 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
| 53 |
+
|
| 54 |
+
sample = torch.multinomial(probs, 1)
|
| 55 |
+
|
| 56 |
+
out = torch.cat((out, sample), dim=-1)
|
| 57 |
+
|
| 58 |
+
cur_len += 1
|
| 59 |
+
bar.update(1)
|
| 60 |
+
|
| 61 |
if return_prime:
|
| 62 |
return out[:, :]
|
| 63 |
|
|
|
|
| 139 |
|
| 140 |
midi_data = TMIDIX.score2midi(output, text_encoding)
|
| 141 |
|
| 142 |
+
with open(f"Allegro-Music-Transformer-Music-Composition.mid", 'wb') as f:
|
| 143 |
f.write(midi_data)
|
| 144 |
|
| 145 |
audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
|
|
|
|
| 196 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 197 |
opt = parser.parse_args()
|
| 198 |
|
| 199 |
+
print('Loading model...')
|
| 200 |
session = rt.InferenceSession('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx', providers=['CUDAExecutionProvider'])
|
| 201 |
+
print('Done!')
|
| 202 |
+
|
| 203 |
load_javascript()
|
| 204 |
app = gr.Blocks()
|
| 205 |
with app:
|