Update app.py

app.py CHANGED
@@ -70,9 +70,9 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
     print(f"Maskable indices count: {len(maskable_indices)}")
     print(f"Mask ratio: {mask_ratio}")

-    # Calculate how many tokens to mask
-    #
-    num_to_mask = max(1,
+    # Calculate how many tokens to mask based on the mask ratio
+    # No arbitrary cap - use the actual percentage
+    num_to_mask = max(1, int(len(maskable_indices) * mask_ratio))
     print(f"Number of tokens to mask: {num_to_mask}")

     # Randomly select indices to mask
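The fix replaces a capped count with one that scales directly with the mask ratio. A minimal standalone sketch of the same rule (the function name and sample sizes are illustrative, not from app.py):

def count_tokens_to_mask(num_maskable: int, mask_ratio: float = 0.15) -> int:
    # Scale with the actual ratio; floor at one token, no fixed cap.
    return max(1, int(num_maskable * mask_ratio))

print(count_tokens_to_mask(40))        # 6  -> 15% of 40
print(count_tokens_to_mask(3))         # 1  -> floored to at least one mask
print(count_tokens_to_mask(200, 0.3))  # 60 -> grows with the sample, uncapped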
@@ -256,15 +256,26 @@ def prepare_next_token_prediction():
     full_hidden = original_text[len(masked_text):].strip()

     # Tokenize the hidden part
-
+    hidden_tokens = tokenizer.tokenize(full_hidden)
+
+    # Print debug info
+    print(f"NTP State setup:")
+    print(f"  Full text: '{original_text}'")
+    print(f"  Visible text: '{masked_text}'")
+    print(f"  Hidden text: '{full_hidden}'")
+    print(f"  Hidden tokens: {hidden_tokens}")
+
+    # Set up the NTP state
+    ntp_state["tokens"] = hidden_tokens
     ntp_state["full_text"] = full_hidden
     ntp_state["revealed_text"] = ""
     ntp_state["next_token_idx"] = 0

     # Make sure we have tokens to predict
     if not ntp_state["tokens"]:
-
-
+        print("Warning: No tokens to predict, will try another sample")
+        # If we don't have tokens, get a new sample with a higher cut ratio
+        new_text = get_new_sample("ntp", 0.4)  # Use higher cut ratio
         prepare_next_token_prediction()

 def check_ntp_answer(user_continuation):
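The new setup tokenizes the hidden continuation before storing it in ntp_state, so the game can reveal it one token at a time. A hedged sketch of that step, assuming a standard Hugging Face AutoTokenizer (the checkpoint and example text are placeholders; app.py's actual tokenizer may differ):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint

full_hidden = "jumped over the lazy dog"
hidden_tokens = tokenizer.tokenize(full_hidden)
print(hidden_tokens)  # ['jumped', 'over', 'the', 'lazy', 'dog']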
@@ -275,6 +286,12 @@ def check_ntp_answer(user_continuation):
     if not ntp_state["tokens"]:
         prepare_next_token_prediction()

+    # Print debug info
+    print(f"Current NTP state:")
+    print(f"  Next token index: {ntp_state['next_token_idx']}")
+    print(f"  Total tokens: {len(ntp_state['tokens'])}")
+    print(f"  User input: '{user_continuation}'")
+
     # No more tokens to predict
     if ntp_state["next_token_idx"] >= len(ntp_state["tokens"]):
         # Reset for next round
@@ -282,6 +299,7 @@ def check_ntp_answer(user_continuation):

     # Get the next token to predict
     next_token = ntp_state["tokens"][ntp_state["next_token_idx"]]
+    print(f"  Expected next token: '{next_token}'")

     # Get user's prediction
     user_text = user_continuation.strip()
@@ -289,6 +307,7 @@ def check_ntp_answer(user_continuation):
     # Tokenize user's prediction to get their first token
     user_tokens = tokenizer.tokenize(user_text)
     user_token = user_tokens[0].lower() if user_tokens else ""
+    print(f"  User's tokenized input: {user_tokens}")

     # Clean up tokens for comparison
     next_token_clean = next_token.lower()
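The answer check scores only the first token of whatever the user typed, guarding against an empty tokenization with the conditional expression. A sketch of that comparison path under the same assumed tokenizer:

user_text = "Jumped right over"
user_tokens = tokenizer.tokenize(user_text)    # ['jumped', 'right', 'over'] with an uncased model
user_token = user_tokens[0].lower() if user_tokens else ""

next_token_clean = "jumped"                    # expected token, lower-cased
print(user_token == next_token_clean)          # True - only the first token is scored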
@@ -300,6 +319,7 @@ def check_ntp_answer(user_continuation):

     # Check if correct
     is_correct = (user_token == next_token_clean)
+    print(f"  Comparison: '{user_token}' vs '{next_token_clean}' -> {'Correct' if is_correct else 'Incorrect'}")

     # Update stats
     if is_correct:
@@ -307,7 +327,7 @@ def check_ntp_answer(user_continuation):
     user_stats["ntp"]["total"] += 1

     # Reveal this token and prepare for next
-    ntp_state["revealed_text"] +=
+    ntp_state["revealed_text"] += tokenizer.convert_tokens_to_string([next_token])
     ntp_state["next_token_idx"] += 1

     # Calculate overall accuracy
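Rebuilding the revealed text through convert_tokens_to_string lets the tokenizer resolve subword markers and spacing instead of leaking them into the UI, which naive concatenation would do. A small illustration, again assuming a WordPiece-style tokenizer (the exact split is model-dependent):

tokens = tokenizer.tokenize("unbelievable")    # e.g. ['un', '##bel', '##ie', '##va', '##ble']
naive = "".join(tokens)                        # 'un##bel##ie##va##ble' - markers leak through
proper = tokenizer.convert_tokens_to_string(tokens)
print(proper)                                  # 'unbelievable'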
@@ -320,7 +340,7 @@ def check_ntp_answer(user_continuation):
     feedback.append(f"✗ Not quite. The actual next token was '{next_token_clean}'")

     # Show progress
-    feedback.append(f"\
+    feedback.append(f"\nText so far: {masked_text}{ntp_state['revealed_text']}")

     # If there are more tokens, prompt for next
     if ntp_state["next_token_idx"] < len(ntp_state["tokens"]):