BinKhoaLe1812 commited on
Commit
c15f7a9
·
verified ·
1 Parent(s): b49a32d

Upd trans dedup

Browse files
Files changed (1) hide show
  1. utils/translation.py +89 -28
utils/translation.py CHANGED
@@ -12,12 +12,48 @@ vi_en = None
12
  zh_en = None
13
 
14
  def _dedupe_repeats(s: str, n_min: int = 3, n_max: int = 7) -> str:
15
- """Collapse excessive repeated n-grams (3..7) and repeated phrases."""
16
  if not s:
17
  return s
 
18
  # Collapse repeated spaces/newlines
19
  s = re.sub(r"\s+", " ", s).strip()
20
- # Heuristic: remove runs of identical tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  tokens = s.split()
22
  out = []
23
  last = None
@@ -25,12 +61,14 @@ def _dedupe_repeats(s: str, n_min: int = 3, n_max: int = 7) -> str:
25
  if last is None or t.lower() != last.lower():
26
  out.append(t)
27
  last = t
28
- s2 = " ".join(out)
 
29
  # Limit consecutive duplicate n-grams
30
  for n in range(n_max, n_min - 1, -1):
31
  pattern = re.compile(r"(\b(?:\w+\s+){%d}\w+\b)(?:\s+\1){2,}" % (n - 1), flags=re.IGNORECASE)
32
- s2 = pattern.sub(r"\1", s2)
33
- return s2
 
34
 
35
 
36
  def _normalize_and_cap(s: str, cap: int = 512) -> str:
@@ -55,26 +93,49 @@ def _is_too_repetitive(s: str, threshold: float = 0.4) -> bool:
55
 
56
  def translate_query(text: str, lang_code: str) -> str:
57
  global vi_en, zh_en
58
- if lang_code == "vi":
59
- if vi_en is None:
60
- vi_en = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en", device=-1)
61
- raw = vi_en(text, max_length=512)[0]["translation_text"]
62
- cleaned = _dedupe_repeats(raw)
63
- norm = _normalize_and_cap(cleaned, cap=512)
64
- if _is_too_repetitive(norm):
65
- logger.warning("[En-Vi] Translation repetitive; falling back to original text")
66
- norm = text
67
- logger.info(f"[En-Vi] Query in `{lang_code}` translated to: {norm}")
68
- return norm
69
- elif lang_code == "zh":
70
- if zh_en is None:
71
- zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
72
- raw = zh_en(text, max_length=512)[0]["translation_text"]
73
- cleaned = _dedupe_repeats(raw)
74
- norm = _normalize_and_cap(cleaned, cap=512)
75
- if _is_too_repetitive(norm):
76
- logger.warning("[En-Zh] Translation repetitive; falling back to original text")
77
- norm = text
78
- logger.info(f"[En-Zh] Query in `{lang_code}` translated to: {norm}")
79
- return norm
80
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  zh_en = None
13
 
14
  def _dedupe_repeats(s: str, n_min: int = 3, n_max: int = 7) -> str:
15
+ """Collapse excessive repeated n-grams and repeated phrases with improved logic."""
16
  if not s:
17
  return s
18
+
19
  # Collapse repeated spaces/newlines
20
  s = re.sub(r"\s+", " ", s).strip()
21
+
22
+ # More aggressive repetition detection
23
+ # Check for simple word repetition (like "a lot of people do not" repeated)
24
+ words = s.split()
25
+ if len(words) > 20: # Only check if text is long enough
26
+ # Look for repeated sequences of 3-8 words
27
+ for seq_len in range(8, 2, -1):
28
+ if len(words) < seq_len * 3: # Need at least 3 repetitions
29
+ continue
30
+
31
+ # Check each possible starting position
32
+ for start in range(len(words) - seq_len * 2):
33
+ sequence = words[start:start + seq_len]
34
+ # Count how many times this sequence repeats
35
+ repeat_count = 1
36
+ pos = start + seq_len
37
+ while pos + seq_len <= len(words):
38
+ if words[pos:pos + seq_len] == sequence:
39
+ repeat_count += 1
40
+ pos += seq_len
41
+ else:
42
+ break
43
+
44
+ # If we found 3+ repetitions, remove the excess
45
+ if repeat_count >= 3:
46
+ # Keep only the first occurrence
47
+ new_words = words[:start + seq_len] + words[start + seq_len * repeat_count:]
48
+ s = " ".join(new_words)
49
+ words = s.split()
50
+ break
51
+ else:
52
+ continue
53
+ break # Break outer loop if we found and fixed a repetition
54
+
55
+ # Additional cleanup for remaining patterns
56
+ # Remove consecutive identical word
57
  tokens = s.split()
58
  out = []
59
  last = None
 
61
  if last is None or t.lower() != last.lower():
62
  out.append(t)
63
  last = t
64
+ s = " ".join(out)
65
+
66
  # Limit consecutive duplicate n-grams
67
  for n in range(n_max, n_min - 1, -1):
68
  pattern = re.compile(r"(\b(?:\w+\s+){%d}\w+\b)(?:\s+\1){2,}" % (n - 1), flags=re.IGNORECASE)
69
+ s = pattern.sub(r"\1", s)
70
+
71
+ return s
72
 
73
 
74
  def _normalize_and_cap(s: str, cap: int = 512) -> str:
 
93
 
94
  def translate_query(text: str, lang_code: str) -> str:
95
  global vi_en, zh_en
96
+
97
+ if not text or not text.strip():
98
+ return text
99
+
100
+ try:
101
+ if lang_code == "vi":
102
+ if vi_en is None:
103
+ logger.info("[Translation] Loading Vietnamese-English model...")
104
+ vi_en = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en", device=-1)
105
+
106
+ # Limit input length to prevent model issues
107
+ input_text = text[:1000] if len(text) > 1000 else text
108
+ raw = vi_en(input_text, max_length=512)[0]["translation_text"]
109
+ cleaned = _dedupe_repeats(raw)
110
+ norm = _normalize_and_cap(cleaned, cap=512)
111
+
112
+ if _is_too_repetitive(norm) or len(norm.strip()) < 10:
113
+ logger.warning("[En-Vi] Translation repetitive or too short; falling back to original text")
114
+ return text
115
+
116
+ logger.info(f"[En-Vi] Query in `{lang_code}` translated to: {norm[:100]}...")
117
+ return norm
118
+
119
+ elif lang_code == "zh":
120
+ if zh_en is None:
121
+ logger.info("[Translation] Loading Chinese-English model...")
122
+ zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
123
+
124
+ # Limit input length to prevent model issues
125
+ input_text = text[:1000] if len(text) > 1000 else text
126
+ raw = zh_en(input_text, max_length=512)[0]["translation_text"]
127
+ cleaned = _dedupe_repeats(raw)
128
+ norm = _normalize_and_cap(cleaned, cap=512)
129
+
130
+ if _is_too_repetitive(norm) or len(norm.strip()) < 10:
131
+ logger.warning("[En-Zh] Translation repetitive or too short; falling back to original text")
132
+ return text
133
+
134
+ logger.info(f"[En-Zh] Query in `{lang_code}` translated to: {norm[:100]}...")
135
+ return norm
136
+
137
+ except Exception as e:
138
+ logger.error(f"[Translation] Translation failed for {lang_code}: {e}")
139
+ return text # Fallback to original text
140
+
141
+ return text