Spaces: Build error
Update app.py
app.py CHANGED
@@ -11,9 +11,9 @@ from loguru import logger
 
 import aiohttp
 import gradio as gr
-
-from
-
+
+from langchain.prompts import PromptTemplate
+
 from langchain_google_genai import ChatGoogleGenerativeAI
 
 import bibtexparser
@@ -32,7 +32,7 @@ class Config:
     default_headers: dict = field(default_factory=lambda: {
         'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
     })
-    log_level: str = 'DEBUG'
+    log_level: str = 'DEBUG'
 
 class ArxivXmlParser:
     NS = {
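
For reference, the Config hunk relies on standard dataclass field rules: a mutable default such as default_headers has to go through field(default_factory=...), while an immutable default such as log_level can be assigned directly. A minimal standalone sketch of the same pattern (MiniConfig is a made-up name, not part of the app):

    from dataclasses import dataclass, field

    @dataclass
    class MiniConfig:
        # Mutable default: must be built through default_factory.
        default_headers: dict = field(default_factory=lambda: {'User-Agent': 'Mozilla/5.0'})
        # Immutable default: a plain assignment is fine.
        log_level: str = 'DEBUG'
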
@@ -111,8 +111,8 @@ class ArxivXmlParser:
             'year': year
         }]
         writer = BibTexWriter()
-        writer.indent = '    '
-        writer.comma_first = False
+        writer.indent = '    '
+        writer.comma_first = False
         return writer.write(db).strip()
 
 class AsyncContextManager:
@@ -135,13 +135,7 @@ class CitationGenerator:
             google_api_key=config.gemini_api_key,
             streaming=True
         )
-        self.
-        self.generate_queries_chain = self._create_generate_queries_chain()
-        logger.remove()
-        logger.add(sys.stderr, level=config.log_level)  # Configure logger
-
-    def _create_citation_chain(self):
-        citation_prompt = PromptTemplate.from_template(
+        self.citation_prompt = PromptTemplate.from_template(
             """Insert citations into the provided text using LaTeX \\cite{{key}} commands.
 
 You must not alter the original wording or structure of the text beyond adding citations.
@@ -154,15 +148,8 @@ class CitationGenerator:
 {papers}
 """
         )
-
-
-            | citation_prompt
-            | self.llm
-            | StrOutputParser()
-        )
-
-    def _create_generate_queries_chain(self):
-        generate_queries_prompt = PromptTemplate.from_template(
+
+        self.generate_queries_prompt = PromptTemplate.from_template(
             """Generate {num_queries} diverse academic search queries based on the given text.
 The queries should be concise and relevant.
 
@@ -174,20 +161,18 @@ class CitationGenerator:
 Text: {text}
 """
         )
-
-
-
-            | self.llm
-            | StrOutputParser()
-        )
+
+        logger.remove()
+        logger.add(sys.stderr, level=config.log_level)
 
     async def generate_queries(self, text: str, num_queries: int) -> List[str]:
+        input_map = {
+            "text": text,
+            "num_queries": num_queries
+        }
         try:
-
-
-            "num_queries": num_queries
-            })
-
+            prompt = self.generate_queries_prompt.format(**input_map)
+            response = await self.llm.apredict(prompt)
             content = response.strip()
             if not content.startswith('['):
                 start = content.find('[')
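
The removed LCEL pipelines (the `| citation_prompt | self.llm | StrOutputParser()` chains above) give way to a plain format-then-predict round trip. A minimal sketch of that call pattern, assuming only a configured ChatGoogleGenerativeAI instance passed in as llm and the async apredict helper that this commit itself calls (ask_for_queries is a hypothetical name):

    from langchain.prompts import PromptTemplate

    async def ask_for_queries(llm, text: str, num_queries: int) -> str:
        # Build the prompt string explicitly instead of composing a chain.
        template = PromptTemplate.from_template(
            "Generate {num_queries} diverse academic search queries based on the given text.\n"
            "Text: {text}"
        )
        prompt = template.format(text=text, num_queries=num_queries)
        return await llm.apredict(prompt)  # same async single-string call the commit uses
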
@@ -206,7 +191,7 @@ class CitationGenerator:
             return ["deep learning neural networks"]
 
         except Exception as e:
-            logger.error(f"Error generating queries: {e}")
+            logger.error(f"Error generating queries: {e}")
             return ["deep learning neural networks"]
 
     async def search_arxiv(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
@@ -218,7 +203,6 @@ class CitationGenerator:
             'sortBy': 'relevance',
             'sortOrder': 'descending'
         }
-
         async with session.get(
             self.config.arxiv_base_url + urllib.parse.urlencode(params),
             headers=self.config.default_headers,
@@ -234,7 +218,6 @@ class CitationGenerator:
     async def fix_author_name(self, author: str) -> str:
         if not re.search(r'[�]', author):
             return author
-
         try:
             prompt = f"""Fix this author name that contains corrupted characters (�):
 
@@ -244,18 +227,12 @@ class CitationGenerator:
 1. Return ONLY the fixed author name
 2. Use proper diacritical marks for names
 3. Consider common name patterns and languages
-4. If unsure
+4. If unsure, use the most likely letter
 5. Maintain the format: "Lastname, Firstname"
-
-Example fixes:
-- "Gonz�lez" -> "González"
-- "Cristi�n" -> "Cristián"
 """
-
-
-            fixed_name = response.content.strip()
+            response = await self.llm.apredict(prompt)
+            fixed_name = response.strip()
             return fixed_name if fixed_name else author
-
         except Exception as e:
             logger.error(f"Error fixing author name: {e}")
             return author
@@ -276,7 +253,7 @@ class CitationGenerator:
             writer.comma_first = False
             return writer.write(bib_database).strip()
         except Exception as e:
-            logger.error(f"Error cleaning BibTeX special characters: {e}")
+            logger.error(f"Error cleaning BibTeX special characters: {e}")
             return text
 
     async def search_crossref(self, session: aiohttp.ClientSession, query: str, max_results: int) -> List[Dict]:
@@ -341,31 +318,24 @@ class CitationGenerator:
 
                     bibtex_text = await bibtex_response.text()
 
-                    # Parse the BibTeX entry
                     bib_database = bibtexparser.loads(bibtex_text)
                     if not bib_database.entries:
                         continue
                     entry = bib_database.entries[0]
 
-                    # Check if 'title' or 'booktitle' is present
                     if 'title' not in entry and 'booktitle' not in entry:
-                        continue
-
-                    # Check if 'author' is present
+                        continue
                     if 'author' not in entry:
-                        continue
+                        continue
 
-                    # Extract necessary fields
                     title = entry.get('title', 'No Title').replace('{', '').replace('}', '')
                     authors = entry.get('author', 'Unknown').replace('\n', ' ').replace('\t', ' ').strip()
                     year = entry.get('year', 'Unknown')
 
-                    # Generate a unique BibTeX key
                     key = self._generate_unique_bibtex_key(entry, existing_keys)
                     entry['ID'] = key
                     existing_keys.add(key)
 
-                    # Use BibTexWriter to format the entry
                     writer = BibTexWriter()
                     writer.indent = '    '
                     writer.comma_first = False
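
Only the inline comments are dropped in this hunk; the flow is unchanged: load the CrossRef BibTeX with bibtexparser, skip incomplete entries, re-key, and re-serialize with BibTexWriter. A small self-contained sketch of that load/re-key/write cycle, using a made-up entry and key:

    import bibtexparser
    from bibtexparser.bwriter import BibTexWriter

    raw = "@article{tmp, title={An Example}, author={Doe, Jane}, year={2024}}"  # made-up entry
    db = bibtexparser.loads(raw)
    entry = db.entries[0]
    entry['ID'] = 'doe2024an'           # re-key, as _generate_unique_bibtex_key does for real entries
    writer = BibTexWriter()
    writer.indent = '    '
    writer.comma_first = False
    print(writer.write(db).strip())     # serialized entry under the new key
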
@@ -378,10 +348,9 @@ class CitationGenerator:
                         'bibtex_key': key,
                         'bibtex_entry': formatted_bibtex
                     })
-                except Exception as e:
-                    logger.error(f"Error processing CrossRef item: {e}")  # Replace print with logger
-                    continue
 
+                except Exception as e:
+                    logger.error(f"Error processing CrossRef item: {e}")
             return papers
 
         except aiohttp.ClientError as e:
@@ -393,7 +362,7 @@ class CitationGenerator:
                 await asyncio.sleep(delay)
 
         except Exception as e:
-            logger.error(f"Error searching CrossRef: {e}")
+            logger.error(f"Error searching CrossRef: {e}")
             return []
 
     def _generate_unique_bibtex_key(self, entry: Dict, existing_keys: set) -> str:
@@ -402,18 +371,15 @@ class CitationGenerator:
         year = entry.get('year', '')
         authors = [a.strip() for a in author_field.split(' and ')]
         first_author_last_name = authors[0].split(',')[0] if authors else 'unknown'
-
+
         if entry_type == 'inbook':
-            # Use 'booktitle' for 'inbook' entries
             booktitle = entry.get('booktitle', '')
             title_word = re.sub(r'\W+', '', booktitle.split()[0]) if booktitle else 'untitled'
         else:
-            # Use regular 'title' for other entries
             title = entry.get('title', '')
             title_word = re.sub(r'\W+', '', title.split()[0]) if title else 'untitled'
-
+
         base_key = f"{first_author_last_name}{year}{title_word}"
-        # Ensure the key is unique
         key = base_key
         index = 1
         while key in existing_keys:
@@ -422,70 +388,93 @@ class CitationGenerator:
         return key
 
     async def process_text(self, text: str, num_queries: int, citations_per_query: int,
-
+                           use_arxiv: bool = True, use_crossref: bool = True) -> tuple[str, str]:
         if not (use_arxiv or use_crossref):
             return "Please select at least one source (ArXiv or CrossRef)", ""
 
         num_queries = min(max(1, num_queries), self.config.max_queries)
         citations_per_query = min(max(1, citations_per_query), self.config.max_citations_per_query)
 
-
-
-            return text, ""
-
-        async with self.async_context as session:
-            search_tasks = []
-            for query in queries:
-                if use_arxiv:
-                    search_tasks.append(self.search_arxiv(session, query, citations_per_query))
-                if use_crossref:
-                    search_tasks.append(self.search_crossref(session, query, citations_per_query))
+        async def generate_queries_tool(input_data: dict):
+            return await self.generate_queries(input_data["text"], input_data["num_queries"])
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        async def search_papers_tool(input_data: dict):
+            queries = input_data["queries"]
+            papers = []
+            async with self.async_context as session:
+                search_tasks = []
+                for q in queries:
+                    if input_data["use_arxiv"]:
+                        search_tasks.append(self.search_arxiv(session, q, input_data["citations_per_query"]))
+                    if input_data["use_crossref"]:
+                        search_tasks.append(self.search_crossref(session, q, input_data["citations_per_query"]))
+                results = await asyncio.gather(*search_tasks, return_exceptions=True)
+                for r in results:
+                    if not isinstance(r, Exception):
+                        papers.extend(r)
+            # Deduplicate
+            unique_papers = []
+            seen_keys = set()
+            for p in papers:
+                if p['bibtex_key'] not in seen_keys:
+                    seen_keys.add(p['bibtex_key'])
+                    unique_papers.append(p)
+            return unique_papers
 
-
-
-
-
+        async def cite_text_tool(input_data: dict):
+            try:
+                citation_input = {
+                    "text": input_data["text"],
+                    "papers": json.dumps(input_data["papers"], indent=2)
+                }
+                prompt = self.citation_prompt.format(**citation_input)
+                response = await self.llm.apredict(prompt)
+                cited_text = response.strip()
+
+                # Aggregate BibTeX entries
+                bib_database = BibDatabase()
+                for p in input_data["papers"]:
+                    if 'bibtex_entry' in p:
+                        bib_db = bibtexparser.loads(p['bibtex_entry'])
+                        if bib_db.entries:
+                            bib_database.entries.append(bib_db.entries[0])
+                        else:
+                            logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
+                writer = BibTexWriter()
+                writer.indent = '    '
+                writer.comma_first = False
+                bibtex_entries = writer.write(bib_database).strip()
+                return cited_text, bibtex_entries
+            except Exception as e:
+                logger.error(f"Error inserting citations: {e}")
+                return input_data["text"], ""
+
+        async def agent_run(input_data: dict):
+            queries = await generate_queries_tool(input_data)
+            papers = await search_papers_tool({
+                "queries": queries,
+                "citations_per_query": input_data["citations_per_query"],
+                "use_arxiv": input_data["use_arxiv"],
+                "use_crossref": input_data["use_crossref"]
             })
+            if not papers:
+                return input_data["text"], ""
+            cited_text, final_bibtex = await cite_text_tool({
+                "text": input_data["text"],
+                "papers": papers
+            })
+            return cited_text, final_bibtex
 
-
-
-
-
-
-
-
-
-                            logger.warning(f"Empty BibTeX entry for key: {p['bibtex_key']}")
-            writer = BibTexWriter()
-            writer.indent = '    '
-            writer.comma_first = False
-            bibtex_entries = writer.write(bib_database).strip()
-
-            return cited_text, bibtex_entries
-        except Exception as e:
-            logger.error(f"Error inserting citations: {e}")  # Replace print with logger
-            return text, ""
+        final_text, final_bibtex = await agent_run({
+            "text": text,
+            "num_queries": num_queries,
+            "citations_per_query": citations_per_query,
+            "use_arxiv": use_arxiv,
+            "use_crossref": use_crossref
+        })
+        return final_text, final_bibtex
 
 def create_gradio_interface() -> gr.Interface:
-    # Removed CitationGenerator initialization here
     async def process(api_key: str, text: str, num_queries: int, citations_per_query: int,
                       use_arxiv: bool, use_crossref: bool) -> tuple[str, str]:
         if not api_key.strip():
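
The reworked process_text runs all per-query searches concurrently and tolerates individual source failures through return_exceptions=True, keeping only the non-exception results. A small self-contained sketch of that gather-and-filter pattern (the fetch functions here are hypothetical stand-ins, not the app's):

    import asyncio

    async def fetch_ok(query: str) -> list:
        return [f"paper for {query}"]             # stands in for a successful search

    async def fetch_broken(query: str) -> list:
        raise RuntimeError("source unavailable")  # stands in for a failing search

    async def main():
        tasks = [fetch_ok("deep learning"), fetch_broken("deep learning")]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        papers = []
        for r in results:
            if not isinstance(r, Exception):      # failed sources are skipped, not fatal
                papers.extend(r)
        print(papers)

    asyncio.run(main())
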
@@ -600,12 +589,11 @@ def create_gradio_interface() -> gr.Interface:
 
     with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
         gr.HTML("""<div class="header">
-            <h1>📚 AutoCitation</h1>
-            <p>Insert citations into your academic text</p>
-        </div>""")
+            <h1>📚 AutoCitation</h1>
+            <p>Insert citations into your academic text</p>
+        </div>""")
 
         with gr.Group(elem_classes="input-group"):
-            # Added API Key input field
             api_key = gr.Textbox(
                 label="Gemini API Key",
                 placeholder="Enter your Gemini API key...",
@@ -623,7 +611,7 @@ def create_gradio_interface() -> gr.Interface:
                     label="Search Queries",
                     value=3,
                     minimum=1,
-                    maximum=
+                    maximum=Config.max_queries,
                     step=1
                 )
             with gr.Column(scale=1):
@@ -631,7 +619,7 @@ def create_gradio_interface() -> gr.Interface:
                     label="Citations per Query",
                     value=1,
                     minimum=1,
-                    maximum=
+                    maximum=Config.max_citations_per_query,
                     step=1
                 )
 
@@ -669,7 +657,6 @@ def create_gradio_interface() -> gr.Interface:
                 show_copy_button=True
             )
 
-        # Updated the inputs and outputs
         process_btn.click(
             fn=process,
             inputs=[api_key, input_text, num_queries, citations_per_query, use_arxiv, use_crossref],
@@ -679,6 +666,10 @@ def create_gradio_interface() -> gr.Interface:
     return demo
 
 if __name__ == "__main__":
-
     demo = create_gradio_interface()
-
+    try:
+        demo.launch(server_port=7860, share=False)
+    except KeyboardInterrupt:
+        print("\nShutting down server...")
+    except Exception as e:
+        print(f"Error starting server: {str(e)}")
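
As a usage sketch, the reworked pipeline could also be driven without the Gradio UI. This assumes that Config accepts gemini_api_key as a field and that CitationGenerator is constructed from a Config, neither of which is shown in this diff:

    import asyncio

    async def demo_run():
        # Hypothetical wiring: Config(gemini_api_key=...) and CitationGenerator(config)
        # are assumptions about parts of app.py not touched by this commit.
        config = Config(gemini_api_key="YOUR_KEY")
        generator = CitationGenerator(config)
        cited, bibtex = await generator.process_text(
            "Transformers changed natural language processing.",
            num_queries=2,
            citations_per_query=1,
            use_arxiv=True,
            use_crossref=False,
        )
        print(cited)
        print(bibtex)

    asyncio.run(demo_run())
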