Juristische-NER / app.py
harshildarji's picture
Update app.py
da37e40 verified
raw
history blame
5.55 kB
import os
import re
import string
import matplotlib.cm as cm
import streamlit as st
from charset_normalizer import detect
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
logging,
pipeline,
)
# Browser-tab title/icon and a full-width layout for the app.
st.set_page_config(
    layout="wide",
    page_icon="⚖️",
    page_title="German Legal NER",
)

# Only let the transformers logger report errors.
logging.set_verbosity(logging.ERROR)

# Stylesheet injected into the page: tightens the main container padding,
# hides the header/footer, and defines the `.entity` span whose nested
# `.tooltip` becomes visible on hover.
_PAGE_CSS = """
<style>
.block-container {
padding-top: 1rem;
padding-bottom: 5rem;
padding-left: 3rem;
padding-right: 3rem;
}
header, footer {visibility: hidden;}
.entity {
position: relative;
display: inline-block;
background-color: transparent;
font-weight: normal;
cursor: help;
}
.entity .tooltip {
visibility: hidden;
background-color: #333;
color: #fff;
text-align: center;
border-radius: 4px;
padding: 2px 6px;
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
white-space: nowrap;
opacity: 0;
transition: opacity 0.05s;
font-size: 11px;
}
.entity:hover .tooltip {
visibility: visible;
opacity: 1;
}
.entity.marked {
background-color: rgba(255, 230, 0, 0.4);
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Load model
# Hugging Face access token, read from the "tkn" environment variable
# (None when the variable is unset).
tkn = os.getenv("tkn")


@st.cache_resource
def _load_ner_pipeline():
    """Build the JuraNER tokenizer, model and token-classification pipeline.

    Streamlit re-executes this script on every widget interaction (for
    example each slider move), so the loading is wrapped in
    ``st.cache_resource``: the objects are created once per server process
    and reused on later reruns instead of being rebuilt each time.

    Returns:
        A ``(tokenizer, model, ner)`` tuple.
    """
    # NOTE(review): recent transformers releases prefer ``token=`` over
    # ``use_auth_token=`` — confirm against the pinned transformers version
    # before switching.
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=tkn
    )
    loaded_model = AutoModelForTokenClassification.from_pretrained(
        "harshildarji/JuraNER", use_auth_token=tkn
    )
    ner_pipeline = pipeline("ner", model=loaded_model, tokenizer=loaded_tokenizer)
    return loaded_tokenizer, loaded_model, ner_pipeline


tokenizer, model, ner = _load_ner_pipeline()
# Entity labels
# Maps the label suffix of a predicted tag (the part after the last "-",
# as extracted in highlight_entities) to the English description shown in
# the hover tooltip. Insertion order also fixes each label's colour index.
entity_labels = dict(
    (
        ("AN", "Lawyer"),
        ("EUN", "European legal norm"),
        ("GRT", "Court"),
        ("GS", "Law"),
        ("INN", "Institution"),
        ("LD", "Country"),
        ("LDS", "Landscape"),
        ("LIT", "Legal literature"),
        ("MRK", "Brand"),
        ("ORG", "Organization"),
        ("PER", "Person"),
        ("RR", "Judge"),
        ("RS", "Court decision"),
        ("ST", "City"),
        ("STR", "Street"),
        ("UN", "Company"),
        ("VO", "Ordinance"),
        ("VS", "Regulation"),
        ("VT", "Contract"),
    )
)
# Fixed colors
def generate_fixed_colors(keys, alpha=0.25):
    """Assign every key a fixed CSS ``rgba(...)`` colour from the ``tab20`` map.

    The colormap is resampled to ``len(keys)`` entries and the i-th key gets
    the i-th colour, so the mapping is stable for a given key order.

    Args:
        keys: Sized iterable of hashable labels.
        alpha: Opacity written verbatim into each ``rgba(...)`` string.

    Returns:
        Dict mapping each key to a string such as
        ``"rgba(31, 119, 180, 0.25)"``; an empty dict when ``keys`` is empty.
    """
    if not keys:
        # Nothing to colour; also avoids resampling a colormap to 0 entries.
        return {}

    # Imported here because the module only binds ``matplotlib.cm`` as ``cm``.
    import matplotlib

    try:
        # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and
        # removed in 3.9; the colormap registry is its replacement.
        cmap = matplotlib.colormaps["tab20"].resampled(len(keys))
    except AttributeError:
        # Older Matplotlib without the registry or ``Colormap.resampled``.
        cmap = cm.get_cmap("tab20", len(keys))

    rgba_colors = {}
    for i, key in enumerate(keys):
        # cmap(i) yields floats in [0, 1]; its own alpha channel is ignored.
        r, g, b, _ = cmap(i)
        rgba_colors[key] = (
            f"rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})"
        )
    return rgba_colors
# One highlight colour per known entity label, in dictionary order.
ENTITY_COLORS = generate_fixed_colors(list(entity_labels), alpha=0.30)

# UI: heading, file picker restricted to .txt, and the score cut-off slider
# (range 0.0-1.0, default 0.8, step 0.01), followed by a horizontal rule.
st.markdown("#### German Legal NER")
uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
threshold = st.slider(
    "Confidence threshold:",
    min_value=0.0,
    max_value=1.0,
    value=0.8,
    step=0.01,
)
st.markdown("---")
# Merge logic
def merge_entities(entities):
    """Fuse token-level NER hits that sit on consecutive token indices.

    Tokens are processed in ascending ``"index"`` order. A token whose index
    is exactly one greater than the current span's index extends that span:
    ``"##"``-prefixed pieces are glued on directly, other tokens are joined
    with a space. The span keeps the first token's ``"entity"`` and
    ``"start"``, takes the last token's ``"end"`` and ``"index"``, and gets
    the mean of its tokens' scores. The input dicts are not modified.

    After merging, each span's text is tidied (spacing around ``.``, ``,``
    and ``/`` normalised; surrounding whitespace and ASCII punctuation
    stripped). Spans whose text ends up with one character or fewer, or
    without any word character, are dropped.

    Args:
        entities: List of dicts with at least ``"index"``, ``"word"``,
            ``"score"`` and ``"end"`` keys.

    Returns:
        List of merged span dicts in index order; ``[]`` for empty input.
    """
    if not entities:
        return []

    # Stage 1: group tokens. Each bucket is [span_dict, score_total, n_tokens].
    buckets = []
    for token in sorted(entities, key=lambda item: item["index"]):
        if buckets and token["index"] == buckets[-1][0]["index"] + 1:
            bucket = buckets[-1]
            span = bucket[0]
            piece = token["word"]
            span["word"] += (piece[2:] if piece.startswith("##") else " " + piece)
            span["end"] = token["end"]
            span["index"] = token["index"]
            bucket[1] += token["score"]
            bucket[2] += 1
        else:
            # Gap in the indices (or first token): open a new span on a copy.
            buckets.append([token.copy(), token["score"], 1])

    # Stage 2: average scores, tidy the text, and filter out noise spans.
    spacing_rules = (
        (r"\s*\.\s*", "."),
        (r"\s*,\s*", ", "),
        (r"\s*/\s*", "/"),
    )
    trim_chars = string.whitespace + string.punctuation
    result = []
    for span, score_total, n_tokens in buckets:
        span["score"] = score_total / n_tokens
        text = span["word"].strip()
        for pattern, replacement in spacing_rules:
            text = re.sub(pattern, replacement, text)
        text = text.strip(trim_chars)
        if len(text) <= 1 or not re.search(r"\w", text):
            continue
        span["word"] = text
        result.append(span)
    return result
# HTML highlighting
def highlight_entities(line, merged_entities, threshold):
    """Render one text line as HTML with confident entities highlighted.

    Entities scoring below ``threshold`` are skipped. Each remaining entity
    replaces ``line[start:end]`` with a coloured ``<span class="entity
    marked">`` containing the entity's ``"word"`` and a nested tooltip span
    holding the label description. Text between entities is copied through.

    The line comes from an uploaded file and the result is rendered with
    ``unsafe_allow_html=True``, so every piece of text (plain segments,
    entity words and tooltip labels) is HTML-escaped; only the markup
    generated here reaches the page unescaped.

    Args:
        line: The original text line.
        merged_entities: Entity dicts with ``"score"``, ``"start"``,
            ``"end"``, ``"entity"`` and ``"word"`` keys, in text order.
        threshold: Minimum score an entity needs to be highlighted.

    Returns:
        The HTML string for the line.
    """
    from html import escape

    parts = []
    cursor = 0  # position in ``line`` up to which text has been emitted
    for ent in merged_entities:
        if ent["score"] < threshold:
            continue
        start, end = ent["start"], ent["end"]
        # "B-PER" / "I-PER" -> "PER"; unknown labels fall back to themselves.
        label = ent["entity"].split("-")[-1]
        label_desc = entity_labels.get(label, label)
        color = ENTITY_COLORS.get(label, "#cccccc")
        parts.append(escape(line[cursor:start]))
        highlight_style = f"background-color:{color}; font-weight:600;"
        parts.append(
            f'<span class="entity marked" style="{highlight_style}">'
            f'{escape(ent["word"])}'
            f'<span class="tooltip">{escape(label_desc)}</span></span>'
        )
        cursor = end
    parts.append(escape(line[cursor:]))
    return "".join(parts)
# Main flow: decode the uploaded file, run NER line by line, render the result.
if uploaded_file:
    # getvalue() returns the whole buffer regardless of the current read
    # position, unlike read(), which yields nothing once the stream has
    # already been consumed.
    raw_bytes = uploaded_file.getvalue()
    encoding = detect(raw_bytes)["encoding"]

    text = None
    if encoding is None:
        st.error("Could not detect file encoding.")
    else:
        try:
            text = raw_bytes.decode(encoding)
        except (LookupError, UnicodeDecodeError):
            # Unknown codec name or bytes that do not fit the detected
            # encoding: report it in the page instead of crashing the app.
            st.error(f"Could not decode the file as {encoding}.")

    if text is not None:
        for line in text.splitlines():
            if not line.strip():
                # Keep blank lines so the paragraph structure survives.
                st.write("")
                continue
            tokens = ner(line)
            merged = merge_entities(tokens)
            html_line = highlight_entities(line, merged, threshold)
            st.markdown(
                f'<div style="margin:0;padding:0;line-height:1.4;">{html_line}</div>',
                unsafe_allow_html=True,
            )