Spaces:
Runtime error
Runtime error
import streamlit as st | |
from huggingface_hub import HfApi, HfFolder | |
import datasets | |
import pandas as pd | |
import logging | |
import os | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
import torch.nn.functional as F | |
def login():
    """Log in to the Hugging Face Hub once per Streamlit session.

    The token is read from the ``HF_TOKEN`` environment variable and stored
    through ``HfFolder``. A ``'logged'`` key in ``st.session_state`` marks
    the session as authenticated so later calls are no-ops.

    Returns:
        True when this call performed the login; False when the session was
        already logged in or when no token is configured.
    """
    if 'logged' in st.session_state:
        logging.info("Already logged in")
        return False

    logging.info("Trying to log in to HF")
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Nothing to store: passing None on to the token writers cannot
        # succeed, so skip the login and leave the session unmarked.
        logging.warning("HF_TOKEN is not set; skipping Hugging Face login")
        return False

    api = HfApi()
    # set_access_token is not available in every huggingface_hub release
    # (deprecated, then removed); call it only where it still exists.
    set_access_token = getattr(api, "set_access_token", None)
    if callable(set_access_token):
        set_access_token(token)

    folder = HfFolder()
    folder.save_token(token)

    # Mark the session only after the token was stored, so a failed attempt
    # is retried on the next call instead of being reported as logged in.
    st.session_state['logged'] = True
    logging.info("Succesfully logged")
    return True
def load_model():
    """Load the sentence-embedding model together with its tokenizer.

    Both are fetched with ``from_pretrained`` for the checkpoint
    ``sentence-transformers/all-MiniLM-L6-v2``. Nothing is cached here:
    every call builds fresh objects.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = 'sentence-transformers/all-MiniLM-L6-v2'

    logging.info("Trying to load model")
    tok = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    logging.info("Model loaded")

    return encoder, tok
def load_index():
    """Load the articles dataset from disk and attach its FAISS index.

    The dataset is read from ``Data/articles.hf`` and the FAISS index stored
    in ``Data/articles.index`` is registered under the name ``embedding``.

    Returns:
        The dataset object with the FAISS index loaded.
    """
    logging.info("Trying to load index")

    articles = datasets.Dataset.load_from_disk("Data/articles.hf")
    logging.info("Articles dataset loaded")

    articles.load_faiss_index("embedding", "Data/articles.index")
    logging.info("FAISS index loaded")

    return articles
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dimension 1, weighted by the mask.

    Args:
        model_output: indexable whose first element holds the token
            embeddings (presumably shaped batch x tokens x features --
            confirm against the model in use).
        attention_mask: tensor broadcast over the last dimension of the
            embeddings; positions with 0 do not contribute to the average.

    Returns:
        The masked sum over dimension 1 divided by the mask total, with the
        divisor clamped to at least 1e-9 to avoid division by zero.
    """
    embeddings = model_output[0]

    # Stretch the mask across the feature dimension so it can be multiplied
    # element-wise with the embeddings.
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()

    weighted_sum = torch.sum(embeddings * weights, 1)
    weight_total = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / weight_total
def get_embedding(query, model, tokenizer):
    """Encode ``query`` into an L2-normalised embedding array.

    Args:
        query: text handed to the tokenizer.
        model: callable accepting the tokenizer output as keyword arguments.
        tokenizer: callable producing the model inputs, including an
            ``attention_mask`` entry.

    Returns:
        The pooled, normalised embedding converted with ``.numpy()``.
    """
    inputs = tokenizer(
        query, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        model_output = model(**inputs)
        pooled = mean_pooling(model_output, inputs['attention_mask'])
        unit_vectors = F.normalize(pooled, p=2, dim=1)

    return unit_vectors.numpy()
def get_answers(query, num_answers):
    """Return the nearest articles for ``query`` as a list of records.

    The model, tokenizer and index are loaded on every call. The query is
    embedded, flattened to one dimension, and looked up in the index's
    ``embedding`` FAISS index.

    Args:
        query: search text.
        num_answers: number of nearest examples requested from the index.

    Returns:
        A list of dicts, one per retrieved example.
    """
    logging.info("Getting answers for {}".format(query))

    model, tokenizer = load_model()
    index = load_index()

    query_vector = get_embedding(query, model, tokenizer).reshape(-1)
    _scores, columns = index.get_nearest_examples(
        'embedding', query_vector, num_answers)

    # Column-oriented result -> one record per row.
    frame = pd.DataFrame.from_dict(columns)
    logging.info("Succesfully got answers for {}".format(query))
    return frame.to_dict('records')
def display_article(article):
    """Render one article: a linked title, its abstract, and a divider.

    Args:
        article: mapping with ``'id'`` (used to build the arXiv URL),
            ``'title'`` and ``'abstract'`` entries.
    """
    import html  # local import: only this function builds raw HTML

    href = "https://arxiv.org/abs/{}".format(article['id'])
    # The heading is rendered with unsafe_allow_html=True, so the dynamic
    # parts must be escaped: a title containing '<', '>' or '&' would
    # otherwise break the markup or inject arbitrary HTML.
    title = "<h3><a href=\"{}\">{}</a></h3>".format(
        html.escape(href, quote=True),
        html.escape(str(article['title'])))

    with st.container():
        st.write(title, unsafe_allow_html=True)
        st.markdown(article['abstract'])
        st.write("---")
def display_answers(query, max_answers=100):
    """Fetch answers for ``query`` and render the first few of them.

    Up to ``max_answers`` articles are requested, but only the first
    ``st.session_state['num_articles_to_show']`` are displayed.

    Args:
        query: search text forwarded to ``get_answers``.
        max_answers: number of answers requested from ``get_answers``.
    """
    st.write("---")

    retrieved = get_answers(query, max_answers)
    visible_count = st.session_state['num_articles_to_show']
    visible = retrieved[:visible_count]

    for entry in visible:
        display_article(entry)