#!/usr/bin/env python3 # -*-coding:utf-8-*- # SPDX-FileCopyrightText: 2021-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Apache-2.0 # # WARNING: we don't check for Python build-time dependencies until # check_environment() function below. If possible, avoid importing # any external libraries here - put in external script, or import in # their specific function instead. import sys import csv import json import argparse import pandas as pd import numpy as np import serial from os import path from io import StringIO from PyQt5.Qt import * from pyqtgraph import PlotWidget from PyQt5 import QtCore import pyqtgraph as pg from pyqtgraph import ScatterPlotItem from PyQt5.QtCore import pyqtSignal, QThread import threading import time from scipy.optimize import minimize import matplotlib.pyplot as plt from scipy.stats import linregress import statsmodels.api as sm # Reduce displayed waveforms to avoid display freezes CSI_VAID_SUBCARRIER_INTERVAL = 1 csi_vaid_subcarrier_len =0 CSI_DATA_INDEX = 200 # buffer size CSI_DATA_COLUMNS = 490 DATA_COLUMNS_NAMES_C5C6 = ['type', 'id', 'mac', 'rssi', 'rate','noise_floor','fft_gain','agc_gain', 'channel', 'local_timestamp', 'sig_len', 'rx_state', 'len', 'first_word', 'data'] DATA_COLUMNS_NAMES = ['type', 'id', 'mac', 'rssi', 'rate', 'sig_mode', 'mcs', 'bandwidth', 'smoothing', 'not_sounding', 'aggregation', 'stbc', 'fec_coding', 'sgi', 'noise_floor', 'ampdu_cnt', 'channel', 'secondary_channel', 'local_timestamp', 'ant', 'sig_len', 'rx_state', 'len', 'first_word', 'data'] csi_data_array = np.zeros( [CSI_DATA_INDEX, CSI_DATA_COLUMNS], dtype=np.float64) csi_data_phase = np.zeros([CSI_DATA_INDEX, CSI_DATA_COLUMNS], dtype=np.float64) csi_data_complex = np.zeros([CSI_DATA_INDEX, CSI_DATA_COLUMNS], dtype=np.complex64) agc_gain_data = np.zeros([CSI_DATA_INDEX], dtype=np.float64) fft_gain_data = np.zeros([CSI_DATA_INDEX], dtype=np.float64) fft_gains = [] agc_gains = [] class csi_data_graphical_window(QWidget): def __init__(self): super().__init__() self.resize(1280, 900) self.plotWidget_ted = PlotWidget(self) self.plotWidget_ted.setGeometry(QtCore.QRect(0, 0, 640, 300)) self.plotWidget_ted.setYRange(-2*np.pi, 2*np.pi) self.plotWidget_ted.addLegend() self.plotWidget_ted.setTitle('Phase Data - Last Frame') # 添加标题 self.plotWidget_ted.setLabel('left', 'Phase (rad)') # Y轴标签 self.plotWidget_ted.setLabel('bottom', 'Subcarrier Index') # X轴标签 self.csi_amplitude_array = np.abs(csi_data_complex) self.csi_phase_array = np.angle(csi_data_complex) self.curve = self.plotWidget_ted.plot([], name='CSI Row Data', pen='r') self.plotWidget_multi_data = PlotWidget(self) self.plotWidget_multi_data.setGeometry(QtCore.QRect(0, 300, 1280, 300)) self.plotWidget_multi_data.getViewBox().enableAutoRange(axis=pg.ViewBox.YAxis) self.plotWidget_multi_data.addLegend() self.plotWidget_multi_data.setTitle('Subcarrier Amplitude Data') # 添加标题 self.plotWidget_multi_data.setLabel('left', 'Amplitude') # Y轴标签 self.plotWidget_multi_data.setLabel('bottom', 'Time (Cumulative Packet Count)') # X轴标签 self.curve_list = [] agc_curve = self.plotWidget_multi_data.plot( agc_gain_data, name='AGC Gain', pen=[255,255,0]) fft_curve = self.plotWidget_multi_data.plot( fft_gain_data, name='FFT Gain', pen=[255,255,0]) self.curve_list.append(agc_curve) self.curve_list.append(fft_curve) for i in range(CSI_DATA_COLUMNS): curve = self.plotWidget_multi_data.plot( self.csi_amplitude_array[:, i], name=str(i), pen=(255, 255, 255)) self.curve_list.append(curve) self.plotWidget_phase_data = PlotWidget(self) self.plotWidget_phase_data.setGeometry(QtCore.QRect(0, 600, 1280, 300)) self.plotWidget_phase_data.getViewBox().enableAutoRange(axis=pg.ViewBox.YAxis) self.plotWidget_phase_data.addLegend() self.plotWidget_multi_data.setTitle('Subcarrier Phase Data') # 添加标题 self.plotWidget_multi_data.setLabel('left', 'Phase (rad)') # Y轴标签 self.plotWidget_multi_data.setLabel('bottom', 'Time (Cumulative Packet Count)') # X轴标签 self.curve_phase_list = [] for i in range(CSI_DATA_COLUMNS): phase_curve = self.plotWidget_phase_data.plot( np.angle(self.csi_amplitude_array[:, i]), name=str(i), pen=(255, 255, 255)) self.curve_phase_list.append(phase_curve) # IQ 图窗口 self.plotWidget_iq = PlotWidget(self) self.plotWidget_iq.setGeometry(QtCore.QRect(640, 0, 640, 300)) self.plotWidget_iq.setLabel('left', 'Q (Imag)') self.plotWidget_iq.setLabel('bottom', 'I (Real)') self.plotWidget_iq.setTitle('IQ Plot - Last Frame') view_box = self.plotWidget_iq.getViewBox() view_box.setRange(QtCore.QRectF(-30, -30, 60, 60)) # 可以调整范围的大小,保证原点在中间 self.plotWidget_iq.getViewBox().setAspectLocked(True) self.iq_scatter = ScatterPlotItem(size=6) self.plotWidget_iq.addItem(self.iq_scatter) self.iq_colors = [] self.timer = pg.QtCore.QTimer() self.timer.timeout.connect(self.update_data) self.timer.start(100) self.deta_len = 0 def update_curve_colors(self, color_list): self.deta_len = len(color_list) self.iq_colors = color_list self.plotWidget_ted.setXRange(0, self.deta_len//2) for i in range(self.deta_len): self.curve_list[i].setPen(color_list[i]) self.curve_phase_list[i].setPen(color_list[i]) def update_data(self): i = np.real(csi_data_complex[-1, :]) q = np.imag(csi_data_complex[-1, :]) points = [] for idx in range(self.deta_len): if idx < len(self.iq_colors): color = self.iq_colors[idx] else: color = (200, 200, 200) points.append({'pos': (i[idx], q[idx]), 'brush': pg.mkBrush(color)}) self.iq_scatter.setData(points) self.csi_amplitude_array = np.abs(csi_data_complex) self.csi_phase_array = np.angle(csi_data_complex) self.csi_row_data = self.csi_phase_array[-1, :] self.curve.setData(self.csi_row_data) self.curve_list[CSI_DATA_COLUMNS].setData(agc_gain_data) self.curve_list[CSI_DATA_COLUMNS+1].setData(fft_gain_data) for i in range(CSI_DATA_COLUMNS): self.curve_list[i].setData(self.csi_amplitude_array[:, i]) self.curve_phase_list[i].setData(self.csi_phase_array[:, i]) def generate_subcarrier_colors(red_range, green_range, yellow_range, total_num,interval=1): colors = [] for i in range(total_num): if red_range and red_range[0] <= i <= red_range[1]: intensity = int(255 * (i - red_range[0]) / (red_range[1] - red_range[0])) colors.append((intensity, 0, 0)) elif green_range and green_range[0] <= i <= green_range[1]: intensity = int(255 * (i - green_range[0]) / (green_range[1] - green_range[0])) colors.append((0, intensity, 0)) elif yellow_range and yellow_range[0] <= i <= yellow_range[1]: intensity = int(255 * (i - yellow_range[0]) / (yellow_range[1] - yellow_range[0])) colors.append((0, intensity, intensity)) else: colors.append((200, 200, 200)) return colors def csi_data_read_parse(port: str, csv_writer, log_file_fd,callback=None): global fft_gains, agc_gains set = serial.Serial(port=port, baudrate=921600,bytesize=8, parity='N', stopbits=1) count =0 if set.isOpen(): print('open success') else: print('open failed') return while True: strings = str(set.readline()) if not strings: break strings = strings.lstrip('b\'').rstrip('\\r\\n\'') index = strings.find('CSI_DATA') if index == -1: log_file_fd.write(strings + '\n') log_file_fd.flush() continue csv_reader = csv.reader(StringIO(strings)) csi_data = next(csv_reader) csi_data_len = int (csi_data[-3]) if len(csi_data) != len(DATA_COLUMNS_NAMES) and len(csi_data) != len(DATA_COLUMNS_NAMES_C5C6): print('element number is not equal',len(csi_data),len(DATA_COLUMNS_NAMES) ) # print(csi_data) log_file_fd.write('element number is not equal\n') log_file_fd.write(strings + '\n') log_file_fd.flush() continue try: csi_raw_data = json.loads(csi_data[-1]) except json.JSONDecodeError: print('data is incomplete') log_file_fd.write('data is incomplete\n') log_file_fd.write(strings + '\n') log_file_fd.flush() continue if csi_data_len != len(csi_raw_data): print('csi_data_len is not equal',csi_data_len,len(csi_raw_data)) log_file_fd.write('csi_data_len is not equal\n') log_file_fd.write(strings + '\n') log_file_fd.flush() continue fft_gain = int(csi_data[6]) agc_gain = int(csi_data[7]) fft_gains.append(fft_gain) agc_gains.append(agc_gain) csv_writer.writerow(csi_data) # Rotate data to the left # csi_data_array[:-1] = csi_data_array[1:] # csi_data_phase[:-1] = csi_data_phase[1:] csi_data_complex[:-1] = csi_data_complex[1:] agc_gain_data[:-1] = agc_gain_data[1:] fft_gain_data[:-1] = fft_gain_data[1:] agc_gain_data[-1] = agc_gain fft_gain_data[-1] = fft_gain if count ==0: count = 1 print('none',csi_data_len) if csi_data_len == 106: colors = generate_subcarrier_colors((0,25), (27,53), None, len(csi_raw_data)) elif csi_data_len == 114: colors = generate_subcarrier_colors((0,27), (29,56), None, len(csi_raw_data)) elif csi_data_len == 52: colors = generate_subcarrier_colors((0,12), (13,26), None, len(csi_raw_data)) elif csi_data_len == 234 : colors = generate_subcarrier_colors((0,28), (29,56), (60,116), len(csi_raw_data)) elif csi_data_len == 228 : colors = generate_subcarrier_colors((0,28), (29,57), (57,113), len(csi_raw_data)) elif csi_data_len == 490 : colors = generate_subcarrier_colors((0,61), (62,122), (123,245), len(csi_raw_data)) elif csi_data_len == 128 : colors = generate_subcarrier_colors((0,31), (32,63), None, len(csi_raw_data)) elif csi_data_len == 256 : colors = generate_subcarrier_colors((0,32), (32,63), (64,128), len(csi_raw_data)) elif csi_data_len == 512 : colors = generate_subcarrier_colors((0,63), (64,127), (128,256), len(csi_raw_data)) elif csi_data_len == 384 : colors = generate_subcarrier_colors((0,63), (64,127), (128,192), len(csi_raw_data)) elif csi_data_len > 0 and csi_data_len <= 612: raw_len = len(csi_raw_data) colors = generate_subcarrier_colors((0,raw_len//2), (raw_len//2+1,raw_len-1), None, raw_len) callback(colors) for i in range(csi_data_len // 2): csi_data_complex[-1][i] = complex(csi_raw_data[i * 2 + 1], csi_raw_data[i * 2]) set.close() return class SubThread (QThread): data_ready = pyqtSignal(object) def __init__(self, serial_port, save_file_name, log_file_name): super().__init__() self.serial_port = serial_port save_file_fd = open(save_file_name, 'w') self.log_file_fd = open(log_file_name, 'w') self.csv_writer = csv.writer(save_file_fd) self.csv_writer.writerow(DATA_COLUMNS_NAMES) def run(self): csi_data_read_parse(self.serial_port, self.csv_writer, self.log_file_fd,callback=self.data_ready.emit) def __del__(self): self.wait() self.log_file_fd.close() if __name__ == '__main__': if sys.version_info < (3, 6): print(' Python version should >= 3.6') exit() parser = argparse.ArgumentParser( description='Read CSI data from serial port and display it graphically') parser.add_argument('-p', '--port', dest='port', action='store', required=True, help='Serial port number of csv_recv device') parser.add_argument('-s', '--store', dest='store_file', action='store', default='./csi_data.csv', help='Save the data printed by the serial port to a file') parser.add_argument('-l', '--log', dest='log_file', action='store', default='./csi_data_log.txt', help='Save other serial data the bad CSI data to a log file') args = parser.parse_args() serial_port = args.port file_name = args.store_file log_file_name = args.log_file app = QApplication(sys.argv) subthread = SubThread(serial_port, file_name, log_file_name) window = csi_data_graphical_window() subthread.data_ready.connect(window.update_curve_colors) subthread.start() window.show() sys.exit(app.exec())