You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

115 lines
3.5 KiB
Python

import datetime
from utils.files import create_dir
import pandas as pd
import numpy as np
from config import *
import matplotlib.pyplot as plt
def is_in_bypass_list(column_name: str, bypass_list: tuple) -> bool:
    """Return True when any entry of *bypass_list* occurs inside *column_name*.

    The match is a plain substring test (``entry in column_name``), not an
    equality test.

    Args:
        column_name: Column label to check.
        bypass_list: Substrings that mark a column as one to skip.

    Returns:
        True if at least one entry is a substring of ``column_name``;
        False otherwise, including when ``bypass_list`` is empty.
    """
    return any(bypass in column_name for bypass in bypass_list)
def input_csv_to_df(file_path: str) -> pd.DataFrame:
    """Load the CSV file at *file_path* into a DataFrame.

    The file is parsed by ``pandas.read_csv`` with its default settings.

    Args:
        file_path: Location of the CSV file to read.

    Returns:
        The parsed contents as a ``pandas.DataFrame``.
    """
    return pd.read_csv(file_path)
def averaging_df(df: pd.DataFrame, column_num: int = None):
    """Rescale each non-bypassed numeric column by its maximum, times *column_num*.

    Every numeric column whose name is not matched by ``BYPASS_COLUMNS``
    (via ``is_in_bypass_list``) is replaced by
    ``column / column.max() * column_num``. NaN results are replaced by 0.
    The frame is modified in place and is also returned.

    Args:
        df: Frame to rescale; the selected columns are overwritten.
        column_num: Scale factor. When ``None`` it defaults to the number of
            numeric columns that are not bypassed.

    Returns:
        A ``(df, column_num)`` tuple: the same frame object after rescaling
        and the scale factor that was applied.
    """
    scaled_columns = [
        name
        for name in df.select_dtypes(include=[np.number]).columns
        if not is_in_bypass_list(name, BYPASS_COLUMNS)
    ]
    if column_num is None:
        column_num = len(scaled_columns)
    # Take maxima only over the columns being scaled. Reducing the whole
    # frame also touches object columns that are never used here, and that
    # can raise TypeError when such a column mixes strings with NaN.
    max_values = df[scaled_columns].max()
    for name in scaled_columns:
        scaled = df[name] / max_values[name] * column_num
        # NaN appears for missing inputs and for 0 / 0 when the column
        # maximum is 0; both are mapped to 0.
        df[name] = scaled.fillna(0)
    return df, column_num
def iter_df_to_point(df: pd.DataFrame, column_num: int = None):
    """Turn every row of *df* into an ``{index: (x_values, y_values)}`` entry.

    ``x_values`` holds the row's values from the third column onward (the
    first two columns are skipped). ``y_values`` holds the evenly spaced
    positions ``0, 1, ..., len(x_values) - 1`` as floats.

    Args:
        df: Frame whose rows are converted.
        column_num: Accepted for call compatibility; not used here.

    Returns:
        A list with one single-key dict per row, in row order. Each dict
        maps the row's index label to its ``(x_values, y_values)`` tuple.
    """
    points = []
    for index, row in df.iterrows():
        x_values = row.values[2:]
        # One y position per remaining column.
        y_values = np.linspace(0, len(x_values) - 1, len(x_values))
        points.append({index: (x_values, y_values)})
    return points
def generate_one_plot(x_values, y_values, x_y_size: int) -> plt:
    """Draw one scatter image on pyplot's current figure and return ``plt``.

    The figure is cleared first, then an all-zero pseudocolor background of
    ``x_y_size`` by ``x_y_size`` cells is drawn, and the points are scattered
    on top using the pixel marker. Both axes are limited to
    ``[0, x_y_size]``.

    Args:
        x_values: X coordinates of the points (labelled 'Attribute values').
        y_values: Y coordinates of the points (labelled 'Attributes').
        x_y_size: Edge length of the square plotting area.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure now holds the
        drawing.
    """
    # Start from an empty figure: pyplot keeps drawing on the same implicit
    # figure, so without this every call in one process would stack another
    # mesh and scatter on top of the earlier ones.
    plt.clf()
    yedges = xedges = np.linspace(0, x_y_size, x_y_size)
    background = np.zeros((x_y_size, x_y_size))
    # pcolormesh() draws a pseudocolor plot on a rectangular grid that need
    # not be regular; here it only supplies a uniform background.
    plt.pcolormesh(xedges, yedges, background)
    plt.scatter(x_values, y_values, marker=',', s=1)
    plt.xlim(0, x_y_size)
    plt.ylim(0, x_y_size)
    plt.ylabel('Attributes')
    plt.xlabel('Attribute values')
    # Other colormaps that were tried here: 'gnuplot', 'Greys'.
    plt.set_cmap('BuPu')
    plt.axis('on')
    return plt
# plt.savefig(os.path.join(figure_save_path, qwe + ".png"), bbox_inches='tight', pad_inches=0)  # name each image separately
def save_plt(plt: plt, base_path: str, num: int):
    """Save the drawing held by *plt* as ``<base_path>/<num>.png``.

    Args:
        plt: Object exposing ``savefig``; callers in this file pass the
            pyplot module.
        base_path: Directory the image is written into.
        num: Value used as the file name (without extension).
    """
    target = f"{base_path}/{num}.png"
    # Tight bounding box with zero padding trims the surrounding margin.
    plt.savefig(target, bbox_inches='tight', pad_inches=0)
from multiprocessing import Pool, cpu_count
def process(df: pd.DataFrame):
    """Render one PNG per row of *df* using a pool of worker processes.

    The frame is rescaled with ``averaging_df`` and split into per-row
    points with ``iter_df_to_point``. Each point is then handed to
    ``generate_and_save`` in a worker process. Images are written under
    ``./saves/<YYYYmmddHHMMSS>``, named after the row's index label.

    Args:
        df: Input frame; ``averaging_df`` rescales it in place.

    Raises:
        Exception: The first error raised by any worker task, re-raised
            here after all tasks have finished.
    """
    df, size = averaging_df(df)
    points = iter_df_to_point(df, size)
    base_path = f'./saves/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}'
    create_dir(base_path)
    # The context manager tears the pool down even if submitting a task fails.
    with Pool(cpu_count()) as pool:
        results = []
        for point_dict in points:
            # Each dict carries exactly one {index: point} pair.
            num, point = next(iter(point_dict.items()))
            results.append(
                pool.apply_async(generate_and_save, args=(base_path, point, size, num))
            )
        pool.close()
        pool.join()
    # apply_async keeps a worker's exception inside its result object; it is
    # only surfaced by get(), so check every task instead of dropping errors.
    for result in results:
        result.get()
def generate_and_save(base_path: str, point: tuple, size: int, calculate):
    """Draw one point set and save it as ``<base_path>/<calculate>.png``.

    Args:
        base_path: Directory the image is written into.
        point: Sequence whose first item holds the x values and whose second
            item holds the y values.
        size: Edge length of the square plotting area.
        calculate: Value used as the image's file name.
    """
    x_values, y_values = point[0], point[1]
    figure_source = generate_one_plot(x_values, y_values, size)
    save_plt(figure_source, base_path, calculate)
def process_single_threaded(df: pd.DataFrame):
    """Render one PNG per row of *df* sequentially in the current process.

    The frame is rescaled with ``averaging_df`` and split into per-row
    points with ``iter_df_to_point``. Each point is drawn and saved by
    ``generate_and_save`` under ``./saves/<YYYYmmddHHMMSS>``.

    Args:
        df: Input frame; ``averaging_df`` rescales it in place.
    """
    df, size = averaging_df(df)
    points = iter_df_to_point(df, size)
    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    base_path = f'./saves/{stamp}'
    create_dir(base_path)
    for point_dict in points:
        # Each dict carries exactly one {index: point} pair.
        num, point = next(iter(point_dict.items()))
        # NOTE(review): the plot size here is the number of x values in the
        # row, whereas process() passes the size returned by averaging_df.
        # Confirm which of the two is intended.
        generate_and_save(base_path, point, len(point[0]), num)
if __name__ == "__main__":
    # Script entry point: load the CSV named by CSV_PATH and render it with
    # the multiprocessing pipeline.
    process(input_csv_to_df(CSV_PATH))