Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import string | |
| import matplotlib.cm as cm | |
| import streamlit as st | |
| from charset_normalizer import detect | |
| from transformers import ( | |
| AutoModelForTokenClassification, | |
| AutoTokenizer, | |
| logging, | |
| pipeline, | |
| ) | |
# Page-level configuration; Streamlit requires this to be the first st.* call.
st.set_page_config(
    page_title="German Legal NER",
    page_icon="⚖️",
    layout="wide",
)

# Only let transformers report errors (suppresses its info/warning chatter).
logging.set_verbosity_error()

# Custom CSS: page padding, hidden Streamlit header/footer, and the
# `.entity` / `.tooltip` classes used by the highlighted entity spans.
_PAGE_CSS = """
<style>
.block-container {
padding-top: 1rem;
padding-bottom: 5rem;
padding-left: 3rem;
padding-right: 3rem;
}
header, footer {visibility: hidden;}
.entity {
position: relative;
display: inline-block;
background-color: transparent;
font-weight: normal;
cursor: help;
}
.entity .tooltip {
visibility: hidden;
background-color: #333;
color: #fff;
text-align: center;
border-radius: 4px;
padding: 2px 6px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
white-space: nowrap;
opacity: 0;
transition: opacity 0.05s;
font-size: 11px;
}
.entity:hover .tooltip {
visibility: visible;
opacity: 1;
}
.entity.marked {
background-color: rgba(255, 230, 0, 0.4);
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Load model
@st.cache_resource
def _load_ner_pipeline(token):
    """Load the JuraNER tokenizer and model and wrap them in an NER pipeline.

    Streamlit re-executes this whole script on every widget interaction
    (e.g. each slider move). ``st.cache_resource`` makes the expensive
    ``from_pretrained`` calls happen once per process instead of on every
    rerun.

    Args:
        token: Hugging Face access token passed to ``from_pretrained``
            (may be ``None``).

    Returns:
        ``(tokenizer, model, ner_pipeline)``.
    """
    tok = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=token)
    mdl = AutoModelForTokenClassification.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=token
    )
    return tok, mdl, pipeline("ner", model=mdl, tokenizer=tok)


# Access token read from the `tkn` environment variable (None if unset).
tkn = os.getenv("tkn")
tokenizer, model, ner = _load_ner_pipeline(tkn)
# Entity labels: entity-type code (the tag suffix after "B-"/"I-") mapped to
# the human-readable description shown in the hover tooltip. Declaration
# order also fixes the colour assignment in ENTITY_COLORS.
entity_labels = dict(
    AN="Lawyer",
    EUN="European legal norm",
    GRT="Court",
    GS="Law",
    INN="Institution",
    LD="Country",
    LDS="Landscape",
    LIT="Legal literature",
    MRK="Brand",
    ORG="Organization",
    PER="Person",
    RR="Judge",
    RS="Court decision",
    ST="City",
    STR="Street",
    UN="Company",
    VO="Ordinance",
    VS="Regulation",
    VT="Contract",
)
# Fixed colors
def generate_fixed_colors(keys, alpha=0.25):
    """Assign every key a deterministic, translucent CSS ``rgba(...)`` colour.

    Colours are sampled from matplotlib's ``tab20`` colormap resampled to
    ``len(keys)`` entries, so the result depends only on the order of *keys*.

    Args:
        keys: Ordered iterable of hashable keys (here: entity-type codes).
        alpha: Opacity written into every ``rgba`` string.

    Returns:
        dict mapping each key to a string such as ``"rgba(31, 119, 180, 0.25)"``;
        an empty dict when *keys* is empty.
    """
    keys = list(keys)
    if not keys:
        # A zero-entry colormap is meaningless; nothing to colour.
        return {}
    try:
        # Colormap registry + Colormap.resampled (matplotlib >= 3.6).
        # `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9.
        from matplotlib import colormaps

        cmap = colormaps["tab20"].resampled(len(keys))
    except (ImportError, AttributeError):
        # Older matplotlib without the registry or `resampled`.
        cmap = cm.get_cmap("tab20", len(keys))

    rgba_colors = {}
    for i, key in enumerate(keys):
        r, g, b, _ = cmap(i)
        rgba_colors[key] = (
            f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})"
        )
    return rgba_colors
# One highlight colour per entity type, assigned in entity_labels' key order.
ENTITY_COLORS = generate_fixed_colors([*entity_labels], alpha=0.30)
# UI: title, .txt upload, and the minimum score an entity needs to be highlighted.
st.markdown("#### German Legal NER")
uploaded_file = st.file_uploader(label="Upload a .txt file", type="txt")
threshold = st.slider(
    "Confidence threshold:",
    min_value=0.0,
    max_value=1.0,
    value=0.8,
    step=0.01,
)
st.markdown("---")
# Merge logic
def merge_entities(entities):
    """Merge token-level NER predictions into word/phrase-level entities.

    *entities* is the per-token output of a transformers ``"ner"`` pipeline:
    one dict per (sub)token with keys ``entity`` (a tag such as ``"B-GS"``),
    ``score``, ``index``, ``word``, ``start`` and ``end``.

    A token is appended to the previous entity when it directly follows it
    (consecutive ``index``) and either:

    * it is a WordPiece continuation (``"##..."``), which always belongs to
      the word it continues; or
    * it has the same entity type and its tag does not open a new entity
      (i.e. it is not ``"B-..."``).

    Tokens of a different type, and ``B-`` tokens, start a new entity instead
    of being glued onto the previous one.

    A merged entity keeps the first token's ``entity`` tag and ``start``. It
    takes the last token's ``end`` and ``index``, and its ``score`` becomes
    the mean of its token scores.

    The merged ``word`` is then normalised:

    * spacing around ``.``, ``,`` and ``/`` is fixed;
    * leading/trailing whitespace and ASCII punctuation are stripped.

    Entities whose normalised word is a single character, or contains no word
    character, are dropped.

    Args:
        entities: Iterable of token prediction dicts; not modified.

    Returns:
        A new list of entity dicts in token-index order.
    """
    merged = []
    for ent in sorted(entities, key=lambda e: e["index"]):
        tok = ent["word"]
        is_subword = tok.startswith("##")
        prev = merged[-1] if merged else None

        continues_prev = (
            prev is not None
            and ent["index"] == prev["index"] + 1
            and (
                is_subword
                or (
                    ent["entity"].split("-")[-1] == prev["entity"].split("-")[-1]
                    and not ent["entity"].startswith("B-")
                )
            )
        )

        if continues_prev:
            prev["word"] += tok[2:] if is_subword else " " + tok
            prev["end"] = ent["end"]
            prev["index"] = ent["index"]
            prev["score_sum"] += ent["score"]
            prev["count"] += 1
        else:
            # Work on a copy so the caller's dicts stay untouched.
            new_ent = ent.copy()
            new_ent["score_sum"] = ent["score"]
            new_ent["count"] = 1
            merged.append(new_ent)

    final = []
    for ent in merged:
        # Replace the running sum/count bookkeeping with the mean score.
        ent["score"] = ent.pop("score_sum") / ent.pop("count")

        w = ent["word"].strip()
        w = re.sub(r"\s*\.\s*", ".", w)
        w = re.sub(r"\s*,\s*", ", ", w)
        w = re.sub(r"\s*/\s*", "/", w)
        w = w.strip(string.whitespace + string.punctuation)
        # Drop single characters and pure-punctuation leftovers.
        if len(w) > 1 and re.search(r"\w", w):
            ent["word"] = w
            final.append(ent)
    return final
# HTML highlighting
def highlight_entities(line, merged_entities, threshold, labels=None, colors=None):
    """Render one line of text as HTML with its entities highlighted.

    Each entity with ``score >= threshold`` is wrapped in a
    ``<span class="entity marked">``. The span is coloured by entity type and
    contains a nested ``<span class="tooltip">`` with the type's description.

    The highlighted text is ``line[start:end]``, not the model's reconstructed
    ``word``, so the rendered line reproduces the original characters exactly.
    Whitespace and ASCII punctuation at the edges of a span stay outside the
    highlight.

    All text is HTML-escaped, because the line comes from an uploaded file
    and the result is rendered with ``unsafe_allow_html=True``.

    Args:
        line: The source line that the entity offsets refer to.
        merged_entities: Entity dicts with ``entity``, ``score``, ``start``
            and ``end`` keys, expected in ascending ``start`` order. A span
            that overlaps an already highlighted one is skipped.
        threshold: Minimum score for an entity to be highlighted.
        labels: Mapping of entity type to tooltip text. Defaults to
            ``entity_labels``.
        colors: Mapping of entity type to CSS background colour. Defaults to
            ``ENTITY_COLORS``.

    Returns:
        The HTML string for the line.
    """
    from html import escape  # local import: `html` is not imported at file level

    if labels is None:
        labels = entity_labels
    if colors is None:
        colors = ENTITY_COLORS
    edge_chars = string.whitespace + string.punctuation

    parts = []
    last_end = 0
    for ent in merged_entities:
        if ent["score"] < threshold:
            continue
        start, end = ent["start"], ent["end"]
        # Shrink the span to its core, mirroring the word cleanup in
        # merge_entities, so edge punctuation is rendered but not highlighted.
        span = line[start:end]
        core = span.strip(edge_chars)
        start += len(span) - len(span.lstrip(edge_chars))
        end = start + len(core)
        if start < last_end or start >= end:
            # Overlapping or empty span: leave that text unhighlighted.
            continue

        label = ent["entity"].split("-")[-1]
        label_desc = labels.get(label, label)
        color = colors.get(label, "#cccccc")

        parts.append(escape(line[last_end:start]))
        parts.append(
            f'<span class="entity marked" style="background-color:{escape(color)}; '
            f'font-weight:600;">{escape(line[start:end])}'
            f'<span class="tooltip">{escape(label_desc)}</span></span>'
        )
        last_end = end
    parts.append(escape(line[last_end:]))
    return "".join(parts)
@st.cache_data(max_entries=16)
def _predict_lines(text):
    """Return the raw NER pipeline output for every line of *text*.

    Streamlit reruns this script on every widget interaction. Without this
    cache, each move of the confidence slider would push the whole document
    through the model again, although only the threshold changed. Results are
    cached per *text*. Blank lines yield ``[]``, so the output stays aligned
    with ``text.splitlines()``.
    """
    return [ner(line) if line.strip() else [] for line in text.splitlines()]


if uploaded_file:
    raw_bytes = uploaded_file.read()
    encoding = detect(raw_bytes)["encoding"]
    if encoding is None:
        st.error("Could not detect file encoding.")
    else:
        # Encoding detection is heuristic. Decode leniently (undecodable
        # bytes become U+FFFD) instead of crashing with UnicodeDecodeError.
        text = raw_bytes.decode(encoding, errors="replace")
        for line, tokens in zip(text.splitlines(), _predict_lines(text)):
            if not line.strip():
                # Keep blank lines as vertical spacing in the rendered output.
                st.write("")
                continue
            merged = merge_entities(tokens)
            html_line = highlight_entities(line, merged, threshold)
            st.markdown(
                f'<div style="margin:0;padding:0;line-height:1.4;">{html_line}</div>',
                unsafe_allow_html=True,
            )