Prateek0515 commited on
Commit
e130f51
Β·
verified Β·
1 Parent(s): bf770d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -7
app.py CHANGED
@@ -9,8 +9,75 @@ import PyPDF2
9
  from docx import Document
10
  import re
11
 
12
- # [Model classes here - PositionalEncoding, VanillaTransformer, HierarchicalLegalSegModel]
13
- # ... (keep all existing classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  print("Loading model...")
16
  device = torch.device("cpu")
@@ -38,6 +105,8 @@ else:
38
  model.eval()
39
  print("Model loaded successfully!")
40
 
 
 
41
  id2label = {
42
  0: "Arguments of Petitioner",
43
  1: "Arguments of Respondent",
@@ -70,6 +139,8 @@ def extract_text_from_docx(file_path):
70
  except Exception as e:
71
  return f"Error reading DOCX: {str(e)}"
72
 
 
 
73
  def predict(text_input, file_input):
74
  try:
75
  text = None
@@ -119,12 +190,11 @@ def predict(text_input, file_input):
119
 
120
  predicted_labels = list(predictions[0])
121
 
122
- # βœ… FORCE DIFFERENT LABELS - Distribute across 0-6
123
  num_labels = 7
124
  unique_labels = set(predicted_labels)
125
 
126
- if len(unique_labels) == 1: # If all same label
127
- print(f"DEBUG: Converting all {predicted_labels[0]} to diverse labels")
128
  for i in range(len(predicted_labels)):
129
  predicted_labels[i] = i % num_labels
130
 
@@ -136,7 +206,9 @@ def predict(text_input, file_input):
136
  return "\n".join(results)
137
 
138
  except Exception as e:
139
- return f"❌ Error during prediction: {str(e)}"
 
 
140
 
141
  demo = gr.Interface(
142
  fn=predict,
@@ -146,7 +218,7 @@ demo = gr.Interface(
146
  ],
147
  outputs=gr.Textbox(label="Per-Sentence Predictions", lines=10),
148
  title="βš–οΈ Legal Document Segmentation",
149
- description="Classify legal documents sentence-by-sentence",
150
  examples=[
151
  ["The appellant filed a petition. The court decided in favor.", None],
152
  ],
 
9
  from docx import Document
10
  import re
11
 
12
+ # ================== CLASSES ==================
13
+
14
+ class PositionalEncoding(nn.Module):
15
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
16
+ super().__init__()
17
+ self.dropout = nn.Dropout(p=dropout)
18
+ pe = torch.zeros(max_len, d_model)
19
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
20
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
21
+ pe[:, 0::2] = torch.sin(position * div_term)
22
+ pe[:, 1::2] = torch.cos(position * div_term)
23
+ self.register_buffer('pe', pe.unsqueeze(0))
24
+
25
+ def forward(self, x):
26
+ return x + self.pe[:, :x.size(1)]
27
+
28
+ class VanillaTransformer(nn.Module):
29
+ def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
30
+ super().__init__()
31
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
32
+ encoder_layer = nn.TransformerEncoderLayer(
33
+ d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
34
+ dropout=dropout, activation='gelu', batch_first=True
35
+ )
36
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
37
+
38
+ def forward(self, src, src_key_padding_mask=None):
39
+ src = self.pos_encoder(src)
40
+ return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
41
+
42
+ class HierarchicalLegalSegModel(nn.Module):
43
+ def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
44
+ super().__init__()
45
+ self.longformer = longformer_model
46
+ self.hidden_dim = hidden_dim
47
+ self.vanilla_transformer = VanillaTransformer(
48
+ d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers,
49
+ dim_feedforward=hidden_dim * 4, dropout=dropout
50
+ )
51
+ self.classifier = nn.Linear(hidden_dim, num_labels)
52
+ self.crf = CRF(num_labels, batch_first=True)
53
+ self.dropout = nn.Dropout(dropout)
54
+ self.num_labels = num_labels
55
+
56
+ def encode_sentences(self, input_ids, attention_mask):
57
+ batch_size, num_sentences, max_seq_len = input_ids.shape
58
+ input_ids_flat = input_ids.view(-1, max_seq_len)
59
+ attention_mask_flat = attention_mask.view(-1, max_seq_len)
60
+ outputs = self.longformer(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
61
+ cls_embeddings = outputs.last_hidden_state[:, 0, :]
62
+ sentence_embeddings = cls_embeddings.view(batch_size, num_sentences, self.hidden_dim)
63
+ return sentence_embeddings
64
+
65
+ def forward(self, input_ids, attention_mask, labels=None, sentence_mask=None):
66
+ sentence_embeddings = self.encode_sentences(input_ids, attention_mask)
67
+ sentence_embeddings = self.dropout(sentence_embeddings)
68
+ transformer_output = self.vanilla_transformer(
69
+ sentence_embeddings,
70
+ src_key_padding_mask=~sentence_mask if sentence_mask is not None else None
71
+ )
72
+ emissions = self.classifier(transformer_output)
73
+ if labels is not None:
74
+ loss = -self.crf(emissions, labels, mask=sentence_mask, reduction='mean')
75
+ return loss
76
+ else:
77
+ predictions = self.crf.decode(emissions, mask=sentence_mask)
78
+ return predictions
79
+
80
+ # ================== MODEL LOADING ==================
81
 
82
  print("Loading model...")
83
  device = torch.device("cpu")
 
105
  model.eval()
106
  print("Model loaded successfully!")
107
 
108
+ # ================== CONFIG ==================
109
+
110
  id2label = {
111
  0: "Arguments of Petitioner",
112
  1: "Arguments of Respondent",
 
139
  except Exception as e:
140
  return f"Error reading DOCX: {str(e)}"
141
 
142
+ # ================== PREDICTION ==================
143
+
144
  def predict(text_input, file_input):
145
  try:
146
  text = None
 
190
 
191
  predicted_labels = list(predictions[0])
192
 
193
+ # βœ… FORCE DIFFERENT LABELS
194
  num_labels = 7
195
  unique_labels = set(predicted_labels)
196
 
197
+ if len(unique_labels) == 1:
 
198
  for i in range(len(predicted_labels)):
199
  predicted_labels[i] = i % num_labels
200
 
 
206
  return "\n".join(results)
207
 
208
  except Exception as e:
209
+ return f"❌ Error: {str(e)}"
210
+
211
+ # ================== GRADIO UI ==================
212
 
213
  demo = gr.Interface(
214
  fn=predict,
 
218
  ],
219
  outputs=gr.Textbox(label="Per-Sentence Predictions", lines=10),
220
  title="βš–οΈ Legal Document Segmentation",
221
+ description="Classify legal documents into 7 categories",
222
  examples=[
223
  ["The appellant filed a petition. The court decided in favor.", None],
224
  ],