#!/usr/bin/env python3
# -*-coding:utf-8-*-

# SPDX-FileCopyrightText: 2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
# WARNING: we don't check for Python build-time dependencies until
# check_environment() function below. If possible, avoid importing
# any external libraries here - put in external script, or import in
# their specific function instead.

# NOTE: the import order below is kept exactly as authored; the wildcard
# import from PyQt5.Qt is followed by explicit imports that may rebind names.
import sys
import csv
import json
import argparse
import pandas as pd
import numpy as np

import serial
from os import path
from io import StringIO

from PyQt5.Qt import *
from pyqtgraph import PlotWidget
from PyQt5 import QtCore
import pyqtgraph as pg
from pyqtgraph import ScatterPlotItem
from PyQt5.QtCore import pyqtSignal, QThread
import threading
import time
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from scipy.stats import linregress
import statsmodels.api as sm

# Reduce displayed waveforms to avoid display freezes
CSI_VAID_SUBCARRIER_INTERVAL = 1
csi_vaid_subcarrier_len = 0

CSI_DATA_INDEX = 200  # buffer size (number of frames kept per rolling buffer)
CSI_DATA_COLUMNS = 490  # number of columns in each CSI rolling buffer

# CSV column layouts accepted for a CSI_DATA line; a line is accepted when its
# field count matches the length of any one of these lists.
DATA_COLUMNS_NAMES_C5C6 = ['type', 'id', 'mac', 'rssi', 'rate', 'noise_floor', 'fft_gain', 'agc_gain', 'channel', 'local_timestamp', 'sig_len', 'rx_state', 'len', 'first_word', 'data']
DATA_COLUMNS_NAMES_NEW = ['type', 'mac', 'len', 'first_word', 'data']
DATA_COLUMNS_NAMES = ['type', 'id', 'mac', 'rssi', 'rate', 'sig_mode', 'mcs', 'bandwidth', 'smoothing', 'not_sounding', 'aggregation', 'stbc', 'fec_coding', 'sgi', 'noise_floor', 'ampdu_cnt', 'channel', 'secondary_channel', 'local_timestamp', 'ant', 'sig_len', 'rx_state', 'len', 'first_word', 'data']

# Rolling buffers shared between the serial reader and the plot window.
# Row -1 is the most recent frame.
csi_data_array = np.zeros(
    [CSI_DATA_INDEX, CSI_DATA_COLUMNS], dtype=np.float64)
csi_data_phase = np.zeros([CSI_DATA_INDEX, CSI_DATA_COLUMNS], dtype=np.float64)
csi_data_complex = np.zeros([CSI_DATA_INDEX, CSI_DATA_COLUMNS], dtype=np.complex64)
agc_gain_data = np.zeros([CSI_DATA_INDEX], dtype=np.float64)
fft_gain_data = np.zeros([CSI_DATA_INDEX], dtype=np.float64)
fft_gains = []
agc_gains = []

# Dedicated buffer for RAW_DATA frames (kept separate from the CSI_DATA buffer).
RAW_DATA_COLUMNS = 612  # maximum number of complex RAW_DATA samples per frame
raw_data_complex = np.zeros([CSI_DATA_INDEX, RAW_DATA_COLUMNS], dtype=np.complex64)

# Colour layout per CSI payload length: (red_range, green_range, yellow_range),
# passed straight to generate_subcarrier_colors().
_CSI_COLOR_SCHEMES = {
    106: ((0, 25), (27, 53), None),
    114: ((0, 27), (29, 56), None),
    52: ((0, 12), (13, 26), None),
    234: ((0, 28), (29, 56), (60, 116)),
    228: ((0, 28), (29, 56), (57, 114)),
    328: ((0, 164), None, None),
    490: ((0, 61), (62, 122), (123, 245)),
    128: ((0, 31), (32, 63), None),
    256: ((0, 32), (32, 63), (64, 128)),
    512: ((0, 63), (64, 127), (128, 256)),
    384: ((0, 63), (64, 127), (128, 192)),
}


class csi_data_graphical_window(QWidget):
    """Window with four plots fed from the module-level rolling buffers.

    Layout: last-frame phase (top-left), last-frame IQ scatter (top-right),
    per-subcarrier amplitude history (middle) and per-subcarrier phase
    history (bottom). A 100 ms timer calls update_data() to redraw.
    """

    def __init__(self):
        super().__init__()

        self.resize(1280, 900)

        # Top-left: unwrapped phase of the most recent CSI and RAW frame.
        self.plotWidget_ted = PlotWidget(self)
        self.plotWidget_ted.setGeometry(QtCore.QRect(0, 0, 640, 300))
        self.plotWidget_ted.setYRange(-2 * np.pi, 2 * np.pi)
        self.plotWidget_ted.addLegend()
        self.plotWidget_ted.setTitle('Phase Data - Last Frame')
        self.plotWidget_ted.setLabel('left', 'Phase (rad)')
        self.plotWidget_ted.setLabel('bottom', 'Subcarrier Index')

        self.csi_amplitude_array = np.abs(csi_data_complex)
        self.csi_phase_array = np.angle(csi_data_complex)
        self.curve = self.plotWidget_ted.plot([], name='CSI Phase', pen='r')

        # RAW_DATA phase curve (blue).
        self.raw_amplitude_array = np.abs(raw_data_complex)
        self.raw_phase_array = np.angle(raw_data_complex)
        self.curve_raw = self.plotWidget_ted.plot([], name='RAW Phase', pen='b')

        # Middle: amplitude history, one curve per buffer column.
        self.plotWidget_multi_data = PlotWidget(self)
        self.plotWidget_multi_data.setGeometry(QtCore.QRect(0, 300, 1280, 300))
        self.plotWidget_multi_data.getViewBox().enableAutoRange(axis=pg.ViewBox.YAxis)
        self.plotWidget_multi_data.addLegend()
        self.plotWidget_multi_data.setTitle('Subcarrier Amplitude Data')
        self.plotWidget_multi_data.setLabel('left', 'Amplitude')
        self.plotWidget_multi_data.setLabel('bottom', 'Time (Cumulative Packet Count)')

        # The gain curves are kept out of curve_list so that curve_list[i]
        # always corresponds to buffer column i (and to legend entry str(i)).
        self.agc_curve = self.plotWidget_multi_data.plot(
            agc_gain_data, name='AGC Gain', pen=[255, 255, 0])
        self.fft_curve = self.plotWidget_multi_data.plot(
            fft_gain_data, name='FFT Gain', pen=[255, 255, 0])

        self.curve_list = []
        for i in range(CSI_DATA_COLUMNS):
            curve = self.plotWidget_multi_data.plot(
                self.csi_amplitude_array[:, i], name=str(i), pen=(255, 255, 255))
            self.curve_list.append(curve)

        # Bottom: phase history, one curve per buffer column.
        self.plotWidget_phase_data = PlotWidget(self)
        self.plotWidget_phase_data.setGeometry(QtCore.QRect(0, 600, 1280, 300))
        self.plotWidget_phase_data.getViewBox().enableAutoRange(axis=pg.ViewBox.YAxis)
        self.plotWidget_phase_data.addLegend()
        self.plotWidget_phase_data.setTitle('Subcarrier Phase Data')
        self.plotWidget_phase_data.setLabel('left', 'Phase (rad)')
        self.plotWidget_phase_data.setLabel('bottom', 'Time (Cumulative Packet Count)')

        self.curve_phase_list = []
        for i in range(CSI_DATA_COLUMNS):
            phase_curve = self.plotWidget_phase_data.plot(
                self.csi_phase_array[:, i], name=str(i), pen=(255, 255, 255))
            self.curve_phase_list.append(phase_curve)

        # Top-right: IQ plot window.
        self.plotWidget_iq = PlotWidget(self)
        self.plotWidget_iq.setGeometry(QtCore.QRect(640, 0, 640, 300))
        self.plotWidget_iq.setLabel('left', 'Q (Imag)')
        self.plotWidget_iq.setLabel('bottom', 'I (Real)')
        self.plotWidget_iq.setTitle('IQ Plot - Last Frame')
        view_box = self.plotWidget_iq.getViewBox()
        # The range can be resized; it is symmetric so the origin stays centred.
        view_box.setRange(QtCore.QRectF(-30, -30, 60, 60))
        view_box.setAspectLocked(True)

        self.iq_scatter = ScatterPlotItem(size=6)
        self.plotWidget_iq.addItem(self.iq_scatter)
        self.iq_colors = []

        # Initialise all state before the timer can fire.
        self.deta_len = 0

        self.timer = pg.QtCore.QTimer()
        self.timer.timeout.connect(self.update_data)
        self.timer.start(100)

    def update_curve_colors(self, color_list):
        """Apply one pen colour per curve and resize the last-frame phase X axis.

        color_list holds one colour per payload sample. Entries beyond the
        number of available curves are ignored instead of raising IndexError.
        """
        self.deta_len = len(color_list)
        self.iq_colors = color_list
        self.plotWidget_ted.setXRange(0, self.deta_len // 2)

        # zip() stops at the shortest sequence, which bounds the loop.
        for color, curve, phase_curve in zip(color_list, self.curve_list, self.curve_phase_list):
            curve.setPen(color)
            phase_curve.setPen(color)

    def update_data(self):
        """Timer slot: redraw the plots from the module-level buffers.

        The IQ scatter and the phase-history curves are created in __init__
        but are not refreshed here; their refresh code was left disabled in
        the original implementation.
        """
        # CSI_DATA: last-frame phase.
        self.csi_amplitude_array = np.abs(csi_data_complex)
        self.csi_phase_array = np.angle(csi_data_complex)
        self.csi_row_data = np.unwrap(self.csi_phase_array[-1, :])
        self.curve.setData(self.csi_row_data)

        # RAW_DATA: last-frame phase.
        self.raw_amplitude_array = np.abs(raw_data_complex)
        self.raw_phase_array = np.angle(raw_data_complex)
        self.raw_row_data = np.unwrap(self.raw_phase_array[-1, :])
        self.curve_raw.setData(self.raw_row_data)

        self.agc_curve.setData(agc_gain_data)
        self.fft_curve.setData(fft_gain_data)

        for column, curve in enumerate(self.curve_list):
            curve.setData(self.csi_amplitude_array[:, column])


def generate_subcarrier_colors(red_range, green_range, yellow_range, total_num, interval=1):
    """Build a list of ``total_num`` RGB tuples, one per index.

    Each ``*_range`` is an inclusive ``(first, last)`` index pair or ``None``.
    Inside a range the intensity ramps linearly from 0 at ``first`` to 255 at
    ``last``. Ranges are tested in the order red, green, yellow, so the first
    match wins when they overlap. Indices outside every range get
    ``(200, 200, 200)``.

    The "yellow" range is emitted as ``(0, v, v)`` to match the existing
    display. ``interval`` is currently unused and kept for compatibility.
    """
    def ramp(index, bounds):
        first, last = bounds
        span = last - first
        # A single-index range has no ramp; use full intensity instead of
        # dividing by zero.
        return int(255 * (index - first) / span) if span else 255

    def inside(index, bounds):
        return bool(bounds) and bounds[0] <= index <= bounds[1]

    colors = []
    for i in range(total_num):
        if inside(i, red_range):
            colors.append((ramp(i, red_range), 0, 0))
        elif inside(i, green_range):
            colors.append((0, ramp(i, green_range), 0))
        elif inside(i, yellow_range):
            level = ramp(i, yellow_range)
            colors.append((0, level, level))
        else:
            colors.append((200, 200, 200))

    return colors


def _log_rejected_line(log_file_fd, reason, line):
    """Record a rejected serial line and the reason in the log file."""
    log_file_fd.write(reason + '\n')
    log_file_fd.write(line + '\n')
    log_file_fd.flush()


def _decode_frame(fields, capacity):
    """Decode the length field and sample payload of a parsed CSV line.

    Returns ``(declared_len, sample_count, iq)`` where ``iq`` holds at most
    ``capacity`` complex values, or ``None`` when the line is malformed.
    Samples are interleaved: the odd index is the real part and the even
    index the imaginary part.
    """
    try:
        declared_len = int(fields[-3])
        samples = json.loads(fields[-1])
    except (IndexError, ValueError):
        # json.JSONDecodeError is a ValueError subclass.
        return None
    if not isinstance(samples, list):
        return None

    count = min(len(samples) // 2, capacity)
    try:
        pairs = np.asarray(samples[:2 * count], dtype=np.float64).reshape(count, 2)
    except (TypeError, ValueError):
        return None
    return declared_len, len(samples), pairs[:, 1] + 1j * pairs[:, 0]


def csi_data_read_parse(port: str, csv_writer, log_file_fd, callback=None):
    """Read lines from a serial port and feed CSI/RAW frames into the buffers.

    Lines containing neither ``CSI_DATA`` nor ``RAW_DATA`` are copied to
    ``log_file_fd``. Malformed frames are reported, logged and skipped.
    Valid CSI_DATA lines are written to ``csv_writer`` and appended to
    ``csi_data_complex``; valid RAW_DATA lines are appended to
    ``raw_data_complex``.

    Args:
        port: Serial device name, opened at 921600 baud, 8N1.
        csv_writer: ``csv.writer``-like object receiving each valid CSI row.
        log_file_fd: Writable text file for non-CSI output and rejected lines.
        callback: Optional callable invoked once with the colour list chosen
            for the first recognised CSI frame length.
    """
    ser = serial.Serial(port=port, baudrate=921600, bytesize=8, parity='N', stopbits=1)
    if ser.isOpen():
        print('open success')
    else:
        print('open failed')
        return

    csi_count = 0  # set to 1 once a CSI_DATA colour scheme has been selected
    raw_count = 0  # set to 1 once the first RAW_DATA frame has been seen

    try:
        while True:
            raw_line = ser.readline()
            if not raw_line:
                break
            # Decode the bytes rather than slicing their repr(), which would
            # also strip real leading 'b' or trailing 'r'/'n' characters.
            strings = raw_line.decode('utf-8', errors='backslashreplace').rstrip('\r\n')

            is_raw = 'RAW_DATA' in strings
            if not is_raw and 'CSI_DATA' not in strings:
                log_file_fd.write(strings + '\n')
                log_file_fd.flush()
                continue

            incomplete_msg = 'RAW_DATA is incomplete' if is_raw else 'data is incomplete'
            try:
                csi_data = next(csv.reader(StringIO(strings)))
            except csv.Error:
                print(incomplete_msg)
                _log_rejected_line(log_file_fd, incomplete_msg, strings)
                continue

            # RAW_DATA is handled separately.
            if is_raw:
                frame = _decode_frame(csi_data, RAW_DATA_COLUMNS)
                if frame is None:
                    print(incomplete_msg)
                    _log_rejected_line(log_file_fd, incomplete_msg, strings)
                    continue
                csi_data_len, sample_count, iq = frame
                if csi_data_len != sample_count:
                    print('RAW_DATA csi_data_len is not equal', csi_data_len, sample_count)
                    _log_rejected_line(log_file_fd, 'RAW_DATA csi_data_len is not equal', strings)
                    continue

                if raw_count == 0:
                    raw_count = 1
                    print('RAW_DATA detected, length:', csi_data_len)

                # RAW_DATA needs no colour setup; only its phase is displayed.
                # Shift the history up one row, clear the newest row, then
                # store the frame.
                raw_data_complex[:-1] = raw_data_complex[1:]
                raw_data_complex[-1, :] = 0
                raw_data_complex[-1, :len(iq)] = iq
                continue

            # CSI_DATA handling.
            known_layouts = (len(DATA_COLUMNS_NAMES), len(DATA_COLUMNS_NAMES_C5C6),
                             len(DATA_COLUMNS_NAMES_NEW))
            if len(csi_data) not in known_layouts:
                print('element number is not equal', len(csi_data), len(DATA_COLUMNS_NAMES))
                print(strings)
                _log_rejected_line(log_file_fd, 'element number is not equal', strings)
                continue

            frame = _decode_frame(csi_data, CSI_DATA_COLUMNS)
            if frame is None:
                print(incomplete_msg)
                _log_rejected_line(log_file_fd, incomplete_msg, strings)
                continue
            csi_data_len, sample_count, iq = frame
            if csi_data_len != sample_count:
                print('csi_data_len is not equal', csi_data_len, sample_count)
                _log_rejected_line(log_file_fd, 'csi_data_len is not equal', strings)
                continue

            # Gain parsing is disabled; would be int(csi_data[6]) / int(csi_data[7]).
            fft_gain = 0
            agc_gain = 0
            # NOTE(review): these lists grow for the lifetime of the process.
            fft_gains.append(fft_gain)
            agc_gains.append(agc_gain)

            csv_writer.writerow(csi_data)

            if csi_count == 0:
                print('CSI_DATA detected, length:', csi_data_len)
                scheme = _CSI_COLOR_SCHEMES.get(csi_data_len)
                if scheme is None:
                    # Frames of an unknown length are not displayed.
                    print('Please add more color schemes.')
                    continue
                csi_count = 1
                colors = generate_subcarrier_colors(*scheme, csi_data_len)
                if callback is not None:
                    callback(colors)

            # Shift the history up one row and store the new frame in the
            # last row. The row is cleared first so that a shorter frame does
            # not leave stale samples behind.
            csi_data_complex[:-1] = csi_data_complex[1:]
            agc_gain_data[:-1] = agc_gain_data[1:]
            fft_gain_data[:-1] = fft_gain_data[1:]
            agc_gain_data[-1] = agc_gain
            fft_gain_data[-1] = fft_gain
            csi_data_complex[-1, :] = 0
            csi_data_complex[-1, :len(iq)] = iq
    finally:
        # Release the port even when the loop ends with an exception.
        ser.close()
    return


class SubThread(QThread):
    """Worker thread that runs csi_data_read_parse() for one serial port.

    ``data_ready`` is emitted with the colour list chosen by the reader.
    """

    data_ready = pyqtSignal(object)

    def __init__(self, serial_port, save_file_name, log_file_name):
        super().__init__()
        self.serial_port = serial_port

        # newline='' is what the csv module expects of its output file.
        # Line buffering pushes every row to disk as it is written, matching
        # the log file, which the reader flushes after every write.
        self.save_file_fd = open(save_file_name, 'w', newline='', buffering=1)
        self.log_file_fd = open(log_file_name, 'w')
        self.csv_writer = csv.writer(self.save_file_fd)
        self.csv_writer.writerow(DATA_COLUMNS_NAMES)

    def run(self):
        csi_data_read_parse(self.serial_port, self.csv_writer, self.log_file_fd,
                            callback=self.data_ready.emit)

    def __del__(self):
        self.wait()
        # getattr() guards against an __init__ that failed part-way through.
        for attr in ('save_file_fd', 'log_file_fd'):
            fd = getattr(self, attr, None)
            if fd is not None:
                fd.close()


if __name__ == '__main__':
    if sys.version_info < (3, 6):
        print(' Python version should >= 3.6')
        # Exit with a failure status rather than reporting success.
        sys.exit(1)

    parser = argparse.ArgumentParser(
        description='Read CSI data from serial port and display it graphically')
    parser.add_argument('-p', '--port', dest='port', action='store', required=True,
                        help='Serial port number of csv_recv device')
    parser.add_argument('-s', '--store', dest='store_file', action='store', default='./csi_data.csv',
                        help='Save the data printed by the serial port to a file')
    parser.add_argument('-l', '--log', dest='log_file', action='store', default='./csi_data_log.txt',
                        help='Save other serial data the bad CSI data to a log file')

    args = parser.parse_args()
    serial_port = args.port
    file_name = args.store_file
    log_file_name = args.log_file

    app = QApplication(sys.argv)

    subthread = SubThread(serial_port, file_name, log_file_name)
    window = csi_data_graphical_window()
    subthread.data_ready.connect(window.update_curve_colors)

    subthread.start()
    window.show()

    sys.exit(app.exec())