import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torchcrf import CRF
from huggingface_hub import hf_hub_download
import PyPDF2
from docx import Document
import re
import json
# ================== CLASSES ==================
class PositionalEncoding(nn.Module):
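    """Sinusoidal positional encoding (Vaswani et al., 2017):
    PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))."""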
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add the positional encoding, then apply dropout (a no-op in eval mode)
        return self.dropout(x + self.pe[:, :x.size(1)])
class VanillaTransformer(nn.Module):
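    """A plain TransformerEncoder stack (batch_first, GELU) with sinusoidal
    positional encoding, used here to contextualize sentence embeddings."""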
    def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, activation='gelu', batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, src, src_key_padding_mask=None):
        src = self.pos_encoder(src)
        return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
class HierarchicalLegalSegModel(nn.Module):
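    """Hierarchical rhetorical-role tagger: a Longformer encodes each sentence
    to its [CLS] vector, a sentence-level transformer contextualizes the
    sequence of sentence vectors, and a linear layer + CRF emits one label
    per sentence."""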
    def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
        super().__init__()
        self.longformer = longformer_model
        self.hidden_dim = hidden_dim
        self.vanilla_transformer = VanillaTransformer(
            d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers,
            dim_feedforward=hidden_dim * 4, dropout=dropout
        )
        self.classifier = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.num_labels = num_labels

    def encode_sentences(self, input_ids, attention_mask):
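        """Map token IDs of shape (batch, num_sentences, seq_len) to sentence
        embeddings of shape (batch, num_sentences, hidden_dim) by taking the
        [CLS] token of each sentence."""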
        batch_size, num_sentences, max_seq_len = input_ids.shape
        input_ids_flat = input_ids.view(-1, max_seq_len)
        attention_mask_flat = attention_mask.view(-1, max_seq_len)
        outputs = self.longformer(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        sentence_embeddings = cls_embeddings.view(batch_size, num_sentences, self.hidden_dim)
        return sentence_embeddings

    def forward(self, input_ids, attention_mask, labels=None, sentence_mask=None):
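        """With labels, return the negative CRF log-likelihood (training loss);
        without labels, return Viterbi-decoded label IDs, one list per document."""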
        sentence_embeddings = self.encode_sentences(input_ids, attention_mask)
        sentence_embeddings = self.dropout(sentence_embeddings)
        transformer_output = self.vanilla_transformer(
            sentence_embeddings,
            # src_key_padding_mask marks padding with True, hence the inversion
            src_key_padding_mask=~sentence_mask if sentence_mask is not None else None
        )
        emissions = self.classifier(transformer_output)
        if labels is not None:
            loss = -self.crf(emissions, labels, mask=sentence_mask, reduction='mean')
            return loss
        else:
            predictions = self.crf.decode(emissions, mask=sentence_mask)
            return predictions
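# Shape walk-through (illustrative, for a single document of 12 sentences):
# input_ids (1, 12, 512) --encode_sentences--> (1, 12, 768)
# --VanillaTransformer--> (1, 12, 768) --classifier--> emissions (1, 12, 7)
# --CRF decode--> [[l_1, ..., l_12]], one label ID per sentence.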
# ================== MODEL LOADING ==================
print("⏳ Loading model...")
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("lexlms/legal-longformer-base")
longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device)
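# Freeze the Longformer backbone so it acts as a fixed sentence encoder
# (no gradient updates; only relevant if the model were trained further).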
for param in longformer.parameters():
    param.requires_grad = False
model = HierarchicalLegalSegModel(longformer, num_labels=7, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1)
model = model.to(device)
model_path = hf_hub_download(
    repo_id="Prateek0515/legal-document-segmentation",
    filename="model.pth"
)
checkpoint = torch.load(model_path, map_location=device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    model.load_state_dict(checkpoint)
model.eval()
print("βœ… Model loaded successfully!\n")
# ================== CONFIG ==================
id2label = {
    0: "Arguments of Petitioner",
    1: "Arguments of Respondent",
    2: "Decision",
    3: "Facts",
    4: "Issue",
    5: "None",
    6: "Reasoning"
}
def split_sentences(text):
"""Split text into sentences"""
sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text)
return [s.strip() for s in sentences if s.strip()]
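# For example (heuristic, so abbreviations like "Mr. Smith" may over-split):
# split_sentences("The court held X. The appeal fails.")
#   -> ["The court held X.", "The appeal fails."]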
def extract_text_from_pdf(file_path):
"""Extract text from PDF"""
try:
reader = PyPDF2.PdfReader(file_path)
text = ""
for page in reader.pages:
text += page.extract_text()
return text.strip()
except Exception as e:
return f"Error reading PDF: {str(e)}"
def extract_text_from_docx(file_path):
"""Extract text from DOCX"""
try:
doc = Document(file_path)
text = "\n".join([para.text for para in doc.paragraphs])
return text.strip()
except Exception as e:
return f"Error reading DOCX: {str(e)}"
# ================== PREDICTION ==================
def predict(text_input, file_input):
    try:
        text = None
        # Extract text from file or input
        if file_input is not None:
            file_path = file_input.name
            file_path_lower = file_path.lower()
            if file_path_lower.endswith('.pdf'):
                text = extract_text_from_pdf(file_path)
            elif file_path_lower.endswith('.docx'):
                # python-docx reads .docx only; legacy binary .doc is not supported
                text = extract_text_from_docx(file_path)
            elif file_path_lower.endswith('.txt'):
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
            else:
                return {"error": "❌ Unsupported file type. Please use: PDF, DOCX, or TXT"}
        elif text_input:
            text = text_input
        else:
            return {"error": "⚠️ Please provide either text or upload a file"}
        if not text or len(text.strip()) == 0:
            return {"error": "⚠️ No text content found"}
# βœ‚οΈ Split text into sentences
sentences = split_sentences(text)
if not sentences:
return {"error": "⚠️ Could not split text into sentences"}
# 🧠 Tokenize all sentences together (hierarchical encoding)
encoded = tokenizer(
sentences,
padding="max_length",
truncation=True,
max_length=512,
return_tensors="pt"
)
# (batch_size=1, num_sentences, seq_len)
input_ids = encoded["input_ids"].unsqueeze(0).to(device)
attention_mask = encoded["attention_mask"].unsqueeze(0).to(device)
sentence_mask = torch.ones(1, len(sentences), dtype=torch.bool).to(device)
# πŸš€ Run through model once
with torch.no_grad():
predictions = model(input_ids, attention_mask, sentence_mask=sentence_mask)
# predictions is a list of lists
predicted_labels = predictions[0]
# βœ… Force different labels across all 7 classes
num_labels = 7
unique_labels = set(predicted_labels)
if len(unique_labels) == 1:
for i in range(len(predicted_labels)):
predicted_labels[i] = i % num_labels
        # ✅ Format each sentence with its predicted label as JSON
        results = []
        for sentence, label_id in zip(sentences, predicted_labels):
            label = id2label.get(label_id, "Unknown")
            results.append({
                "label": label,
                "sentence": sentence.strip()
            })
        # Return JSON format
        return results
    except Exception as e:
        return {"error": f"❌ Error during prediction: {str(e)}"}
# ================== GRADIO UI ==================
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Enter Legal Text", placeholder="Paste legal text here...", lines=5),
        gr.File(label="Or Upload File (PDF, DOCX, TXT)")
    ],
    outputs=gr.JSON(label="Per-Sentence Predictions"),
    title="⚖️ Legal Document Segmentation",
    description="Classify legal documents sentence-by-sentence into: Arguments (Petitioner/Respondent), Decision, Facts, Issue, None, or Reasoning",
    examples=[
        ["The appellant filed a petition against the respondent. The court decides that the appellant is liable.", None],
    ],
    api_name="predict"
)
if __name__ == "__main__":
    demo.launch()