# coding=utf-8
"""
根据数据表生成对应的实体
"""
import os
import traceback

import pymysql

import os

#from tools.Mysqls import mysqls

from stock_sql.sql_env import stock_db

# Schema whose tables are converted into entity modules.
datatable = 'nrvsp_dev'
# Appended to by get_tables() (MySQL branch) with every table name in `datatable`.
tables = []

# Parent of the current *working* directory (not of this script's location);
# generated entity modules are written to <that dir>/infos.
fpath = os.path.abspath('..')
_dir = os.path.join(fpath, "infos")
# exist_ok avoids the race between a separate exists() check and mkdir().
os.makedirs(_dir, exist_ok=True)


def get_tables(sqltype, conn, tableName):
    """Read one table's column metadata and generate its entity module.

    For ``sqltype == "mysql"`` every table name of schema ``datatable`` is
    first appended to the module-level ``tables`` list. Then the column
    metadata of ``tableName`` is read from INFORMATION_SCHEMA.COLUMNS. For any
    other ``sqltype`` the metadata is read from the SYSOBJECTS / SYSCOLUMNS /
    SYSCOLUMNCOMMENTS system views (presumably a DM database; confirm). In both
    cases the cursor is handed to ``write_to_file``.

    Errors are printed together with their traceback rather than raised.
    ``conn`` is always closed before returning.

    :param sqltype: ``"mysql"``; any other value selects the system-view query.
    :param conn: DB-API connection; it is closed by this function.
    :param tableName: name of the table to convert.
    """
    try:
        if sqltype == "mysql":
            # Collect every table name of the schema into the global list.
            cur = conn.cursor()
            sqlall = "select table_name  from information_schema.tables where table_schema='" + datatable + "'"
            cur.execute(sqlall)
            tables.extend(v[0] for v in cur.fetchall())
            cur.close()

            cur = conn.cursor()
            # Column order is relied on by write_to_file:
            # EXTRA, COLUMN_NAME, COLUMN_COMMENT, COLUMN_TYPE, IS_NULLABLE, COLUMN_KEY
            sql = "SELECT EXTRA,COLUMN_NAME,COLUMN_COMMENT,COLUMN_TYPE,IS_NULLABLE,COLUMN_KEY FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = '" + datatable + "' AND table_name = '" + tableName + "'"
            cur.execute(sql)
            write_to_file(cur, tableName)
            cur.close()
            print(tableName + '完成表到实体的转换')
        else:
            # The system views store schema names upper-cased ('NRVSP_DEV');
            # derive it from `datatable` so the schema is configured in one place.
            schema = datatable.upper()
            cur = conn.cursor()
            # Column order mirrors the MySQL query: default value, column name,
            # comment, type, nullable, identity flag.
            sql = ("select t.默认值,t.列名,d.COMMENT$,t.类型,t.可为空,t.自增字段  from ("
                   "SELECT S.NAME AS 表名,C.NAME AS 列名, C.TYPE$ AS 类型,C.LENGTH$ AS 长度,"
                   "C.SCALE AS 标度,C.INFO2 AS 自增字段,C.NULLABLE$  AS 可为空,C.DEFVAL AS 默认值 "
                   "FROM SYSOBJECTS S,SYSCOLUMNS C WHERE S.ID = C.ID "
                   "AND S.SCHID = SF_GET_SCHEMA_ID_BY_NAME('" + schema + "') "
                   "AND SUBTYPE$ = 'UTAB' AND S.NAME = '" + tableName + "'"
                   " ORDER BY S.NAME, C.COLID) as t left join ("
                   "select COLNAME,COMMENT$ from SYSCOLUMNCOMMENTS "
                   "where schname = '" + schema + "' and tvname ='" + tableName + "'"
                   ") d on t.列名= d.COLNAME")
            cur.execute(sql)
            print(sql)
            write_to_file(cur, tableName)
            cur.close()
            print(tableName + '完成表到实体的转换')
    except Exception:
        # print_exc() returns None, so concatenating it raised a TypeError and
        # hid the real error; format_exc() returns the traceback as a string.
        print('异常=>' + traceback.format_exc())
    finally:
        conn.close()


def _single_line(value):
    """Return ``str(value)`` with CR/LF characters replaced by spaces.

    The text is embedded in ``#`` comments of the generated module. A raw line
    break would end the comment and turn the remainder into invalid code.
    """
    return str(value).replace('\r', ' ').replace('\n', ' ')


def write_to_file(row, tb, out_dir=None):
    """Generate an entity module ``<out_dir>/<tb lower-cased>.py`` for table ``tb``.

    ``row`` is a cursor whose result rows come from one of the metadata queries
    in ``get_tables``. Index 0 is EXTRA (MySQL) or the column default (other
    branch), 1 is the column name, 2 the comment, and 3 the column type. Rows
    whose index 0 equals ``'auto_increment'`` are skipped.

    The generated module contains:
      * a class named ``tb`` with underscores removed, lower-cased, whose
        ``__init__`` takes every kept column (lower-cased) and stores it as an
        attribute;
      * a module-level ``get_insert_sql()`` returning an INSERT statement with
        one ``%s`` placeholder per kept column;
      * commented-out ``get_data_dict`` / ``get_info`` templates.

    :param row: cursor to read all rows from via ``fetchall()``.
    :param tb: table name.
    :param out_dir: target directory; defaults to the module-level ``_dir``.
    """
    result = row.fetchall()
    if out_dir is None:
        out_dir = _dir

    table = tb.lower()
    # title() is undone by lower(); kept so generated class names stay
    # unchanged (e.g. "stock_income" -> "stockincome").
    class_name = tb.title().replace('_', "").lower()
    # Auto-increment columns are filled by the database, so they are not
    # constructor arguments nor part of the INSERT statement.
    columns = [v for v in result if v[0] != 'auto_increment']
    names = [v[1].lower() for v in columns]

    lines = [
        '# coding=utf-8',
        'class ' + class_name + '():',
        '',
        "#     __tablename__ = '" + table + "'",
        '# 字段解释' + _single_line(result),
        '    def __init__(self' + ''.join(',' + name for name in names) + '):',
    ]
    for column, name in zip(columns, names):
        lines.append('       #注释' + _single_line(column[2]) + "---" + _single_line(column[3]))
        lines.append('       self.' + name + " = " + name)
    if not names:
        # An empty def body would make the generated module a SyntaxError.
        lines.append('       pass')

    lines.append('def get_insert_sql():')
    lines.append('    return "INSERT INTO ' + table + ' (' + ','.join(names) + ')'
                 + ' VALUES (' + ','.join(['%s'] * len(names)) + ');"')
    lines.append('# def get_data_dict(company):')
    lines.extend('    #' + name + " = " for name in names)
    lines.append('    # return {' + ','.join('"' + name + '":' + name for name in names) + '}')
    lines.append('# def get_info(company):')
    lines.append('    #dicts = get_data_dict(company)')
    lines.append('    # return ' + class_name + '('
                 + ','.join('dicts["' + name + '"]' for name in names) + ')')

    path = os.path.join(out_dir, table + '.py')
    # `with` guarantees the file is closed even if writing fails.
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')


if __name__ == "__main__":
    # Tables to generate entity classes for.
    list_tables = ['stock_income']
    for tableName in list_tables:
        con = stock_db
        # get_tables() closes the connection in its `finally`, and every
        # iteration reuses the same stock_db object. PyMySQL's ping(True)
        # reopens a closed connection; presumably stock_db is a PyMySQL
        # connection (this script uses the "mysql" branch) - confirm.
        # Objects without ping() are passed through unchanged.
        if hasattr(con, 'ping'):
            con.ping(True)
        get_tables("mysql", con, tableName)
