# Star import first (as before) so the explicit imports below win on any name clash.
# It is expected to provide the graded functions and `requests`.
from sols_06 import *

import math
import os
import re
import time

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

class DataBase:
    """Thin wrapper around a SQLAlchemy engine for running raw SQL text.

    Parameters
    ----------
    loc : str
        Location appended to the URL after ``:///``; for SQLite this is the
        database file path.
    db_type : str
        Dialect name used as the URL scheme (default ``"sqlite"``).
    """

    def __init__(self, loc: str, db_type: str = "sqlite") -> None:
        """Initialize the class and create the engine for the database."""
        self.loc = loc
        self.db_type = db_type
        # e.g. sqlite:////abs/path/auctions.db for an absolute SQLite path.
        self.engine = create_engine(f'{self.db_type}:///{self.loc}')

    def query(self, q: str) -> pd.DataFrame:
        """Run a query against the database and return the result as a DataFrame."""
        # pandas acquires and releases its own connection when given an Engine,
        # so no ORM Session is needed just to reach the engine.
        return pd.read_sql(q, self.engine)

    def execute(self, q: str) -> None:
        """Execute a statement on the database and commit it."""
        # `engine.connect()` does not autocommit in SQLAlchemy 2.x: writes made
        # on it are rolled back when the block exits unless committed.
        # `engine.begin()` commits on success and rolls back on error.
        with self.engine.begin() as conn:
            conn.execute(text(q))

# Path to the auctions SQLite database. Set the AUCTIONS_DB environment variable
# to point at your own copy instead of editing this file; the default is the
# original author's location.
path = os.environ.get('AUCTIONS_DB', '/Users/hlukas/git/teaching/solutions/auctions.db')
auctions = DataBase(path)

# Reference answer for std(): the sample standard deviation of bid amounts per item.
# Inner query: each bid's squared deviation from its item's mean bid (window avg
# partitioned by itemid). Despite its name, `demeaned_bidamount` holds the
# *squared* deviation.
# Outer query: sqrt(sum of squared deviations / (n - 1)), the n-1 (sample)
# denominator. HAVING count > 1 drops single-bid items, which also avoids a
# zero denominator.
# NOTE(review): sqrt()/power() are SQLite math functions that only exist when
# SQLite is built with them (3.35+) — confirm the local sqlite3 supports them.
std_out = auctions.query(
    """
    select itemid, sqrt(sum(demeaned_bidamount) / (count(demeaned_bidamount) - 1)) as std from (
    select itemid, power(bidamount - avg(bidamount) over (partition by itemid), 2) as demeaned_bidamount
    from bids) as a
    group by itemid
    having count(demeaned_bidamount) > 1
    """
)

# Reference answer for bidder_spend_frac(). Per bidder, it returns:
#   total_bids  - the sum of the bid kept for each item by the innermost query;
#   total_spend - the part of that sum on items where the bid is the item's
#                 overall maximum;
#   spend_frac  - total_spend / total_bids.
# Innermost query: one row per (biddername, itemid). The HAVING on max() relies on
# SQLite's bare-column behaviour to pick the row holding max(bidamount), which is
# presumably meant to keep each bidder's top bid per item.
# Middle query: high_bid is 1 when that bid equals the item's maximum bid, else 0.
# NOTE(review): spend_frac divides two SUMs without a cast. If bidamount is stored
# as INTEGER, SQLite does integer division here — confirm the column is REAL.
bidder_spend_frac_out = auctions.query(
    """
    select biddername, sum(bidamount) as total_bids, sum(bidamount * high_bid) as total_spend, 
    sum(bidamount * high_bid) / sum(bidamount) as spend_frac
    from (
    select itemid, biddername, bidamount, (bidamount = max(bidamount) over (partition by itemid)) as high_bid
    from 
    (select * from bids
    group by biddername, itemid
    having bidamount = max(bidamount)) as a) as b
    group by biddername
    """
)

# Reference answer for win_perc_by_timestamp(): the share of bids that are their
# item's highest bid, grouped by when in the auction the bid was placed.
# timestamp_bin = ceiling(10 * (endtime - bidtime) / (endtime - starttime)): the
# fraction of the auction's duration still remaining at the bid, in tenths.
# julianday() turns the timestamps into fractional day numbers.
# high_bid is 1 when the bid equals its item's maximum bid, else 0. win_perc is
# the mean of high_bid within each bin, with casts to float to avoid integer
# division. The subquery's `timestamp` column is computed but unused outside it.
# Bids with no matching row in `items` (LEFT JOIN) end up in a NULL bin.
# NOTE(review): ceiling() is an optional SQLite math function (3.35+) — confirm
# it is available in the local sqlite3.
win_perc_by_timestamp_out = auctions.query(
    """
    select cast(timestamp_bin as int) as timestamp_bin, 
    cast(sum(high_bid) as float)/ cast(count(*) as float) as win_perc from
    (select b.itemid, b.biddername, 
    (julianday(i.endtime) - julianday(b.bidtime))/(julianday(i.endtime) - julianday(i.starttime)) as timestamp,
    ceiling(((julianday(i.endtime) - julianday(b.bidtime))/(julianday(i.endtime) - julianday(i.starttime)))*10) as timestamp_bin,
    cast(b.bidamount = max(b.bidamount) over (partition by b.itemid) as int) as high_bid
    from bids as b
    left join items as i
    on b.itemid = i.itemid) as a
    where high_bid is not null
    group by timestamp_bin
    """
)

def test_github():
    """Check that the repo behind `github()` has commits in the current week.

    `github()` must return a GitHub file URL of the form
    ``https://github.com/<owner>/<repo>/blob/...``. The repo's weekly commit
    counts are fetched from the GitHub REST API, and the most recent week
    (the last entry of ``all``) must be positive.
    """
    url = github()
    # Capture exactly owner/repo; a greedy `(.+)` could swallow later path parts.
    match = re.search(r'github\.com/([^/]+/[^/]+)/blob', url)
    assert match is not None, f'not a GitHub blob URL: {url!r}'
    api_url = f'https://api.github.com/repos/{match.group(1)}/stats/participation'
    # GitHub replies 202 (with no stats) while it computes repository statistics
    # in the background, so retry briefly before judging the result.
    for _ in range(5):
        req = requests.get(api_url, timeout=30)
        if req.status_code != 202:
            break
        time.sleep(2)
    assert req.status_code == 200, f'GitHub API returned {req.status_code} for {api_url}'
    assert req.json()['all'][-1] > 0

def test_types():
    """Every solution function must hand back its SQL as a plain string."""
    # One assert per function, in the same order as before, so a failure
    # points at the offending function.
    assert isinstance(std(), str)
    assert isinstance(bidder_spend_frac(), str)
    assert isinstance(win_perc_by_timestamp(), str)

def test_std_rows():
    """The student's std() query yields as many rows as the reference."""
    student_rows = auctions.query(std()).shape[0]
    assert student_rows == std_out.shape[0]

def test_std_sds():
    """The student's per-item standard deviations sum to the reference total.

    The sums are compared with `math.isclose` rather than `==`: an equivalent
    query may do the floating-point arithmetic in a different order and differ
    from the reference in the last bits.
    """
    student_total = sum(auctions.query(std())['std'])
    assert math.isclose(student_total, sum(std_out['std']))

def test_bidder_spend_frac_rows():
    """The student's bidder_spend_frac() query yields as many rows as the reference."""
    student_rows = auctions.query(bidder_spend_frac()).shape[0]
    assert student_rows == bidder_spend_frac_out.shape[0]

def test_bidder_spend_frac_fracs():
    """The student's spend fractions sum to the reference total.

    The sums are compared with `math.isclose` rather than `==`, because
    equivalent SQL can produce floats that differ only by rounding.
    """
    student_total = sum(auctions.query(bidder_spend_frac())['spend_frac'])
    assert math.isclose(student_total, sum(bidder_spend_frac_out['spend_frac']))

def test_win_perc_by_timestamp_rows():
    """The student's win_perc_by_timestamp() query yields as many rows (bins) as the reference."""
    student_rows = auctions.query(win_perc_by_timestamp()).shape[0]
    assert student_rows == win_perc_by_timestamp_out.shape[0]

def test_win_perc_by_timestamp_timestamps():
    """The student's per-bin win percentages sum to the reference total.

    The sums are compared with `math.isclose` rather than `==`, because
    equivalent SQL can produce floats that differ only by rounding.
    """
    student_total = sum(auctions.query(win_perc_by_timestamp())['win_perc'])
    assert math.isclose(student_total, sum(win_perc_by_timestamp_out['win_perc']))
