"""Streamlit helpers for semantic search over a local arXiv article index.

Queries are embedded with a sentence-transformers MiniLM model and matched
against a FAISS index attached to a Hugging Face ``datasets.Dataset``.
"""
import html
import logging
import os

import datasets
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from huggingface_hub import HfApi, HfFolder
from transformers import AutoTokenizer, AutoModel

# Single source of truth for the checkpoint and on-disk artefacts.
_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
_DATASET_PATH = "Data/articles.hf"
_INDEX_PATH = "Data/articles.index"
_INDEX_NAME = "embedding"


@st.cache_data
def login():
    """Store the Hugging Face token from the ``HF_TOKEN`` environment variable.

    Returns:
        True when a token was found and saved; False when the session was
        already marked as logged in or when ``HF_TOKEN`` is not set.
    """
    # NOTE(review): st.cache_data memoises the return value, so this body
    # presumably runs only once per server process rather than once per
    # session -- confirm that is intended.
    if 'logged' in st.session_state:
        logging.info("Already logged in")
        return False

    logging.info("Trying to log in to HF")
    st.session_state['logged'] = True

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Without a token there is nothing to save; passing None to the
        # token APIs below would fail instead of degrading gracefully.
        logging.warning("HF_TOKEN is not set; skipping Hugging Face login")
        return False

    # NOTE(review): HfApi.set_access_token looks deprecated in recent
    # huggingface_hub releases -- verify against the pinned version.
    api = HfApi()
    api.set_access_token(hf_token)
    folder = HfFolder()
    folder.save_token(hf_token)
    logging.info("Succesfully logged")
    return True


@st.cache_resource
def load_model():
    """Load the sentence-embedding model and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple for ``_MODEL_NAME``.
    """
    logging.info("Trying to load model")
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    model = AutoModel.from_pretrained(_MODEL_NAME)
    logging.info("Model loaded")
    return model, tokenizer


@st.cache_resource
def load_index():
    """Load the articles dataset from disk and attach its FAISS index.

    Returns:
        The dataset with the FAISS index registered under ``"embedding"``.
    """
    logging.info("Trying to load index")
    index = datasets.Dataset.load_from_disk(_DATASET_PATH)
    logging.info("Articles dataset loaded")
    index.load_faiss_index(_INDEX_NAME, _INDEX_PATH)
    logging.info("FAISS index loaded")
    return index


def mean_pooling(model_output, attention_mask):
    """Average token embeddings, weighting each token by the attention mask.

    Args:
        model_output: Model output whose first element holds the token
            embeddings (assumed ``(batch, seq, dim)`` -- TODO confirm).
        attention_mask: Mask tensor (assumed ``(batch, seq)`` -- TODO confirm).

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total.
    """
    token_embeddings = model_output[0]
    mask = (
        attention_mask.unsqueeze(-1)
        .expand(token_embeddings.size())
        .float()
    )
    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp so an all-zero mask cannot cause a division by zero.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts


def get_embedding(query, model, tokenizer):
    """Embed ``query`` as an L2-normalised, mean-pooled vector.

    Args:
        query: Text passed straight to the tokenizer.
        model: Model returned by ``load_model``.
        tokenizer: Tokenizer returned by ``load_model``.

    Returns:
        A NumPy array of normalised embeddings, one row per input.
    """
    encoded_input = tokenizer(
        query, padding=True, truncation=True, return_tensors='pt')
    # Inference only: skip autograd so the result can be turned into NumPy.
    with torch.no_grad():
        model_output = model(**encoded_input)
    pooled = mean_pooling(model_output, encoded_input['attention_mask'])
    normalised = F.normalize(pooled, p=2, dim=1)
    return normalised.numpy()


@st.cache_data
def get_answers(query, num_answers):
    """Return the ``num_answers`` nearest articles for ``query``.

    Args:
        query: Search text.
        num_answers: Number of nearest neighbours to request from the index.

    Returns:
        A list of dicts, one per retrieved article.
    """
    logging.info("Getting answers for %s", query)
    model, tokenizer = load_model()
    index = load_index()
    # Flatten the (1, dim) embedding to the 1-D vector used for the lookup.
    query_embedding = get_embedding(query, model, tokenizer).reshape(-1)
    _, answers = index.get_nearest_examples(
        _INDEX_NAME, query_embedding, num_answers)
    # The index returns a column-oriented mapping; convert to row records.
    records = pd.DataFrame.from_dict(answers).to_dict('records')
    logging.info("Succesfully got answers for %s", query)
    return records


def display_article(article):
    """Render one article: a title linking to arXiv, then its abstract.

    Args:
        article: Mapping with ``'id'``, ``'title'`` and ``'abstract'`` keys.
    """
    with st.container():
        href = "https://arxiv.org/abs/{}".format(article['id'])
        # The title is emitted as raw HTML, so escape both values to stop
        # markup in the data from being interpreted by the browser.
        title = '<h3><a href="{}">{}</a></h3>'.format(
            html.escape(href, quote=True),
            html.escape(str(article['title'])))
        st.write(title, unsafe_allow_html=True)
        st.markdown(article['abstract'])
        st.write("---")


def display_answers(query, max_answers=100):
    """Fetch answers for ``query`` and render the first few of them.

    Args:
        query: Search text.
        max_answers: Number of articles to retrieve; also the display limit
            when ``num_articles_to_show`` is absent from the session state.
    """
    st.write("---")
    articles = get_answers(query, max_answers)
    # Fall back to showing everything retrieved instead of raising KeyError
    # when the caller has not initialised the counter yet.
    limit = st.session_state.get('num_articles_to_show', max_answers)
    for article in articles[:limit]:
        display_article(article)