import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torchcrf import CRF
from huggingface_hub import hf_hub_download
import PyPDF2
from docx import Document
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding over sentence positions."""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add the non-learned positional signal, then apply dropout.
        return self.dropout(x + self.pe[:, :x.size(1)])
class VanillaTransformer(nn.Module):
    """Document-level encoder that contextualizes sentence embeddings across the document."""
    def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, activation='gelu', batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, src, src_key_padding_mask=None):
        src = self.pos_encoder(src)
        return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
class HierarchicalLegalSegModel(nn.Module):
    """Longformer sentence encoder -> document transformer -> linear emissions -> CRF."""
    def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
        super().__init__()
        self.longformer = longformer_model
        self.hidden_dim = hidden_dim
        self.vanilla_transformer = VanillaTransformer(
            d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers,
            dim_feedforward=hidden_dim * 4, dropout=dropout,
        )
        self.classifier = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def encode_sentences(self, input_ids, attention_mask):
        # Flatten (batch, num_sentences, seq_len) so each sentence runs through
        # Longformer independently, then keep the [CLS] embedding per sentence.
        batch_size, num_sentences, max_seq_len = input_ids.shape
        input_ids_flat = input_ids.view(-1, max_seq_len)
        attention_mask_flat = attention_mask.view(-1, max_seq_len)
        outputs = self.longformer(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        return cls_embeddings.view(batch_size, num_sentences, self.hidden_dim)

    def forward(self, input_ids, attention_mask, sentence_mask=None):
        embeddings = self.encode_sentences(input_ids, attention_mask)
        embeddings = self.dropout(embeddings)
        # PyTorch's key-padding mask marks padded positions with True, hence ~sentence_mask.
        output = self.vanilla_transformer(
            embeddings,
            src_key_padding_mask=~sentence_mask if sentence_mask is not None else None,
        )
        emissions = self.classifier(output)
        return self.crf.decode(emissions, mask=sentence_mask)
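# Shape contract (illustrative values): input_ids and attention_mask are
# LongTensors of shape (batch, num_sentences, seq_len); sentence_mask is a
# BoolTensor of shape (batch, num_sentences) with True marking real sentences.
# crf.decode returns one list of label indices per document, e.g. [[3, 3, 6, 2]]
# for a hypothetical four-sentence document.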
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("lexlms/legal-longformer-base")
longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device)

# Freeze the Longformer backbone; it serves purely as a sentence encoder here.
for param in longformer.parameters():
    param.requires_grad = False

model = HierarchicalLegalSegModel(longformer, num_labels=7).to(device)
model_path = hf_hub_download(repo_id="Prateek0515/legal-document-segmentation", filename="model.pth")
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
id2label = {0: "Arguments of Petitioner", 1: "Arguments of Respondent", 2: "Decision", 3: "Facts", 4: "Issue", 5: "None", 6: "Reasoning"}
def extract_text_from_pdf(file):
    reader = PyPDF2.PdfReader(file)
    text = ""
    for page in reader.pages:
        # extract_text() can return None for pages with no text layer.
        text += page.extract_text() or ""
    return text.strip()

def extract_text_from_docx(file):
    doc = Document(file)
    return "\n".join(para.text for para in doc.paragraphs).strip()
def predict(text_input, file_input):
    try:
        if file_input is not None:
            if file_input.name.endswith('.pdf'):
                text = extract_text_from_pdf(file_input.name)
            elif file_input.name.endswith('.docx'):
                text = extract_text_from_docx(file_input.name)
            elif file_input.name.endswith('.txt'):
                with open(file_input.name, 'r', encoding='utf-8') as f:
                    text = f.read()
            else:
                return "❌ Unsupported file type"
        else:
            text = text_input
        if not text or not text.strip():
            return "⚠️ Please provide text"
        # The whole input is encoded as one "sentence" of shape (1, 1, 512), so
        # the model assigns a single rhetorical label to the entire text.
        encoded = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
        input_ids = encoded["input_ids"].unsqueeze(1).to(device)
        attention_mask = encoded["attention_mask"].unsqueeze(1).to(device)
        sentence_mask = torch.ones(1, 1, dtype=torch.bool).to(device)
        with torch.no_grad():
            predictions = model(input_ids, attention_mask, sentence_mask=sentence_mask)
        label = id2label[predictions[0][0]]
        return f"✅ **Label:** {label}\n\n📄 **Text:** {text[:300]}..."
    except Exception as e:
        return f"❌ Error: {str(e)}"
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Enter Legal Text", lines=5),
        gr.File(label="Or Upload (PDF/DOCX/TXT)"),
    ],
    outputs=gr.Textbox(label="Result", lines=5),
    title="⚖️ Legal Document Segmentation",
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch()
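# Client-side usage sketch (hedged: the Space repo id below is an assumption,
# mirroring the model repo id). The api_name="predict" endpoint can be called
# with gradio_client, e.g.:
#   from gradio_client import Client
#   client = Client("Prateek0515/legal-document-segmentation")
#   print(client.predict("The petitioner contends that ...", None, api_name="/predict"))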