arxiv-search / core.py
eremeev-d's picture
More answers button
832fef5
import html
import logging
import os

import datasets
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from huggingface_hub import HfApi, HfFolder
from transformers import AutoTokenizer, AutoModel
@st.cache_data
def login():
    """Register the ``HF_TOKEN`` environment variable with the Hugging Face Hub.

    The token is passed to ``HfApi.set_access_token`` and persisted with
    ``HfFolder.save_token``. A ``'logged'`` flag in ``st.session_state``
    prevents repeating the work within a session. Note that ``st.cache_data``
    also memoizes the return value, so once a call has completed the body is
    not re-executed for this process.

    Returns:
        True if this call performed the login; False if it was skipped
        because the session was already logged in or no token is configured.
    """
    if 'logged' in st.session_state:
        logging.info("Already logged in")
        return False

    logging.info("Trying to log in to HF")
    HF_TOKEN = os.environ.get("HF_TOKEN")
    if not HF_TOKEN:
        # Nothing to register or persist; do not hand None to the HF client
        # and do not mark the session as logged in.
        logging.warning("HF_TOKEN is not set; skipping Hugging Face login")
        return False

    api = HfApi()
    api.set_access_token(HF_TOKEN)
    folder = HfFolder()
    folder.save_token(HF_TOKEN)
    # Only flag the session once the token was actually stored, so a failure
    # above does not leave a misleading "logged" marker behind.
    st.session_state['logged'] = True
    logging.info("Succesfully logged")
    return True
@st.cache_resource
def load_model():
    """Load the MiniLM sentence encoder together with its tokenizer.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and the
    same objects are reused on subsequent calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'
    logging.info("Trying to load model")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    logging.info("Model loaded")
    return encoder, tokenizer
@st.cache_resource
def load_index():
    """Load the on-disk articles dataset and attach its FAISS index.

    The index file is registered under the name ``"embedding"``, which is the
    name ``get_nearest_examples`` is later queried with. Cached with
    ``st.cache_resource`` so the dataset is loaded only once.

    Returns:
        The ``datasets.Dataset`` with the FAISS index attached.
    """
    logging.info("Trying to load index")
    articles = datasets.Dataset.load_from_disk("Data/articles.hf")
    logging.info("Articles dataset loaded")

    articles.load_faiss_index("embedding", "Data/articles.index")
    logging.info("FAISS index loaded")

    return articles
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where the mask is set.

    Args:
        model_output: model output whose first element holds the token
            embeddings (summed over dimension 1, the token axis).
        attention_mask: mask with one entry per token; padded positions (0)
            do not contribute to the mean.

    Returns:
        One pooled vector per sequence.
    """
    hidden = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    # Clamp avoids division by zero for an all-padding sequence.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def get_embedding(query, model, tokenizer):
    """Encode *query* into L2-normalised, mean-pooled sentence embeddings.

    Args:
        query: text (or texts) accepted by ``tokenizer``.
        model: transformer whose first output holds token embeddings.
        tokenizer: tokenizer matching ``model``.

    Returns:
        A NumPy array with one normalised row per input sequence.
    """
    batch = tokenizer(query, padding=True, truncation=True, return_tensors='pt')
    # Inference only: no autograd graph, so the result can go to NumPy.
    with torch.no_grad():
        output = model(**batch)
    pooled = mean_pooling(output, batch['attention_mask'])
    unit = F.normalize(pooled, p=2, dim=1)
    return unit.numpy()
@st.cache_data
def get_answers(query, num_answers):
    """Retrieve the ``num_answers`` articles closest to *query*.

    The query is embedded with the cached model, then looked up in the
    FAISS index named ``"embedding"``. Results are memoized per
    ``(query, num_answers)`` by ``st.cache_data``.

    Returns:
        A list of dicts, one per retrieved article, keyed by dataset column.
    """
    logging.info("Getting answers for {}".format(query))
    encoder, tokenizer = load_model()
    articles = load_index()

    # Flatten the single-row embedding into a 1-D query vector.
    vector = get_embedding(query, encoder, tokenizer).reshape(-1)
    _scores, nearest = articles.get_nearest_examples('embedding', vector, num_answers)

    frame = pd.DataFrame.from_dict(nearest)
    logging.info("Succesfully got answers for {}".format(query))
    return frame.to_dict('records')
def display_article(article):
    """Render one search result: linked title, abstract and a separator.

    Args:
        article: mapping with ``'id'``, ``'title'`` and ``'abstract'`` keys
            (one record as returned by ``get_answers``).
    """
    with st.container():
        # The title is emitted as raw HTML (unsafe_allow_html=True), so the
        # dataset-provided text must be escaped: titles can contain '<', '>'
        # or '&' (e.g. inline maths), which would otherwise be parsed as
        # markup or allow HTML injection.
        href = html.escape("https://arxiv.org/abs/{}".format(article['id']), quote=True)
        title = "<h3><a href=\"{}\">{}</a></h3>".format(
            href, html.escape(str(article['title'])))
        st.write(title, unsafe_allow_html=True)
        st.markdown(article['abstract'])
        st.write("---")
def display_answers(query, max_answers=100):
    """Render search results for *query*.

    Up to ``max_answers`` results are fetched (and cached by ``get_answers``),
    but only the first ``st.session_state['num_articles_to_show']`` are shown.
    That session key is read directly; it is presumably initialised elsewhere
    in the app (e.g. by a "more answers" control) — TODO confirm.
    """
    st.write("---")
    results = get_answers(query, max_answers)
    limit = st.session_state['num_articles_to_show']
    visible = results[:limit]
    for entry in visible:
        display_article(entry)