import streamlit as st
import tensorflow as tf
import numpy as np
import json
from transformers import BertTokenizer, AutoTokenizer, TFAutoModel
import requests
from bs4 import BeautifulSoup

PATH_t = './checkpoint-7500/'    # earlier checkpoint (not used below)
PATH = './checkpoint-18750/'     # fine-tuned checkpoint loaded for inference
SEQ_LEN = 128                    # fixed input sequence length
tokenizer_t = BertTokenizer.from_pretrained('bert-base-multilingual-cased')  # not used below
tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")

@st.cache_resource
def create_sentiment_bert():
    # Load the fine-tuned BERT backbone; caching keeps Streamlit from
    # reloading the checkpoint on every rerun.
    model = TFAutoModel.from_pretrained(PATH, local_files_only=True)
    # Define the token, mask, and segment inputs.
    token_inputs = tf.keras.layers.Input((SEQ_LEN,), dtype=tf.int32, name='input_word_ids')
    mask_inputs = tf.keras.layers.Input((SEQ_LEN,), dtype=tf.int32, name='input_masks')
    segment_inputs = tf.keras.layers.Input((SEQ_LEN,), dtype=tf.int32, name='input_segment')
    # Run BERT over [tokens, masks, segments].
    bert_outputs = model([token_inputs, mask_inputs, segment_inputs])

    # Index 1 is the pooled [CLS] representation.
    bert_outputs = bert_outputs[1]
    sentiment_first = tf.keras.layers.Dense(1, activation='sigmoid', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))(bert_outputs)
    sentiment_model = tf.keras.Model([token_inputs, mask_inputs, segment_inputs], sentiment_first)

    sentiment_model.compile(loss=tf.keras.losses.BinaryCrossentropy(), metrics=['accuracy'])
    return sentiment_model
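
# A minimal sanity-check sketch (illustrative only; never called by the app).
# It assumes the checkpoint in PATH exists locally, and verifies that the
# model accepts three (1, SEQ_LEN) int32 arrays and returns one sigmoid score.
def _smoke_test_sentiment_bert():
    model = create_sentiment_bert()
    dummy = np.zeros((1, SEQ_LEN), dtype=np.int32)
    pred = model.predict([dummy, dummy, dummy])
    assert pred.shape == (1, 1)  # one probability in [0, 1] per input sentence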

def sentence_convert_data(data):
    # Encode one sentence into the [tokens, masks, segments] batch format
    # the model above expects.
    tokens, masks, segments = [], [], []
    token = tokenizer.encode(data, max_length=SEQ_LEN, truncation=True, padding='max_length')

    # [PAD] has token id 0, so counting zeros gives the padding length.
    num_zeros = token.count(0)
    mask = [1]*(SEQ_LEN-num_zeros) + [0]*num_zeros
    segment = [0]*SEQ_LEN

    tokens.append(token)
    segments.append(segment)
    masks.append(mask)

    tokens = np.array(tokens)
    masks = np.array(masks)
    segments = np.array(segments)
    return [tokens, masks, segments]
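
# Hedged usage sketch (not called by the app). It assumes the klue/bert-base
# tokenizer loaded above, whose [PAD] token id is 0; that property is what the
# zero-counting trick in sentence_convert_data relies on.
def _demo_convert():
    x_tokens, x_masks, x_segments = sentence_convert_data("좋은 기사네요")  # "Nice article"
    print(x_tokens.shape, x_masks.shape, x_segments.shape)  # each (1, SEQ_LEN)
    print(int(x_masks[0].sum()))  # number of real (non-[PAD]) positions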

def evaluation_predict(sentence):
    data_x = sentence_convert_data(sentence)
    predict = sentiment_model.predict(data_x)
    predict_value = np.ravel(predict)
    # 0: negative, 1: positive
    predict_answer = np.round(predict_value, 0).item()
    return predict_answer
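
# Hedged example of the call pattern used in main() below: np.round thresholds
# the sigmoid score at 0.5, so the return value is 0.0 or 1.0.
#
#   if evaluation_predict("정말 좋은 기사네요") == 1:   # "Really nice article"
#       ...  # treated as positive and displayed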

def get_comments(news_url):
    # Extract oid (press/office id) and aid (article id) from the URL path.
    parts = news_url.split("/")
    oid = parts[-2]
    aid = parts[-1]
    # The last path segment may carry a query string; the article id itself
    # is its first 10 characters.
    if len(aid) > 10:
        aid = aid[:10]
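    # Hypothetical example URL, to make the slicing concrete:
    #   "https://n.news.naver.com/mnews/article/001/0014567890?sid=102"
    #   -> oid = "001", aid = "0014567890" (query string trimmed by the slice)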
    
    # Build the comment API request.
    api_url = "https://apis.naver.com/commentBox/cbox/web_naver_list_jsonp.json"
    params = {
        "ticket": "news",
        "templateId": "default_society",
        "pool": "cbox5",
        "lang": "ko",
        "country": "KR",
        "objectId": f"news{oid},{aid}",
        "pageSize": 100,
        "indexSize": 10,
        "page": 1,
        "sort": "FAVORITE"  # 'NEW' (newest first), 'FAVORITE' (most liked first)
    }
    
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Referer": news_url
    }
    
    # Call the API and strip the JSONP wrapper ("_callback(...);") so the
    # body parses as plain JSON.
    response = requests.get(api_url, params=params, headers=headers)
    content = response.text.replace("_callback(", "").replace(");", "")
    json_data = json.loads(content)

    response = requests.get(news_url, headers=headers)
    article_soup = BeautifulSoup(response.text, "html.parser")

    # Extract the title, falling back to the older article layout.
    title = article_soup.select_one("#ct > div.media_end_head.go_trans > div.media_end_head_title > h2")
    if title is None:
        title = article_soup.select_one("#content > div.end_ct > div > h2")

    # Extract the body ("#articeBody" is the id Naver itself uses on older
    # layouts, spelling included).
    article = article_soup.select_one("#dic_area")
    if article is None:
        article = article_soup.select_one("#articeBody")

    return title.text.strip(), article.text.strip(), processing_data(json_data['result']['commentList'])

def processing_data(comments):
    # Keep each comment's text, dropping empty entries (e.g. deleted comments).
    return [comment['contents'] for comment in comments if comment['contents']]
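
# For reference, a hedged sketch of the slice of the API response that
# get_comments and processing_data consume; the real payload carries many
# more fields, omitted here:
#
#   {
#       "result": {
#           "commentList": [
#               {"contents": "comment text ...", ...},
#               ...
#           ]
#       }
#   }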


def main():
    global sentiment_model
    sentiment_model = create_sentiment_bert()
    st.title("Comment Filtering Service")

    # Read the URL, pre-filling it from the "?q=..." query parameter if present.
    value = st.query_params.get("q")
    if value:
        url = st.text_input("Enter a URL", value=value)
    else:
        url = st.text_input("Enter a URL")
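    # Usage note (assumes Streamlit's default port): opening
    # http://localhost:8501/?q=<article-url> pre-fills the text box via the
    # "q" query parameter read above.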

    
    if st.button("스크랩 시작"):
        if url:
            title, content, comments = get_comments(url)
            
            # 결과 표시
            st.subheader("기사 제목")
            st.write(title)
            
            st.subheader("본문 내용")
            st.write(content)
            
            st.subheader("댓글")
            for comment in comments:
                if evaluation_predict(comment) == 1:
                    st.write(comment)
    return 0

if __name__ == "__main__":
    main()