diff --git a/10911831_3.py b/10911831_3.py
new file mode 100644
index 0000000..2c1d9f5
--- /dev/null
+++ b/10911831_3.py
@@ -0,0 +1,49 @@
+# -*- coding: UTF-8 -*-
+import wave
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Open the wav file; wave.open returns a Wave_read instance used to read the format and frames.
+f = wave.open(r"D:\CloudMusic\ss/000005.wav","rb")
+# Read the format information.
+# getparams() returns a tuple: (nchannels, sampwidth in bytes, framerate, nframes,
+# comptype, compname). The wave module only supports uncompressed data, so the last two are ignored.
+params = f.getparams()
+[nchannels, sampwidth, framerate, nframes] = params[:4]
+# Read the waveform data.
+# readframes() takes the number of frames to read and returns the raw bytes.
+str_data = f.readframes(nframes)
+f.close()
+# Convert the raw bytes into an array we can compute with.
+# fix: np.frombuffer replaces np.fromstring, which was deprecated and then removed from numpy.
+wave_data = np.frombuffer(str_data, dtype=np.short)
+# Reshape to 2 columns, rows inferred; the total length must stay unchanged.
+wave_data.shape = -1,2
+# Transpose the data.
+wave_data = wave_data.T
+# Compute the time of each sample from the sample count and sample rate.
+time=np.arange(0,nframes/2)/framerate
+# print(params)
+plt.figure(1)
+# time pairs with wave_data[0] or wave_data[1] to form the plotted points.
+plt.subplot(211)
+plt.plot(time,wave_data[0])
+plt.xlabel("time/s")
+plt.title('Wave')
+
+
+N=44100
+start=0
+# starting sample position
+df = framerate/(N-1)
+# frequency resolution
+freq = [df*n for n in range(0,N)]
+# N elements
+wave_data2=wave_data[0][start:start+N]
+c=np.fft.fft(wave_data2)*2/N
+# conventionally display only half the sampling rate's worth of spectrum (up to Nyquist)
+plt.subplot(212)
+plt.plot(freq[:round(len(freq)/2)],abs(c[:round(len(c)/2)]),'r')
+plt.title('Freq')
+plt.xlabel("Freq/Hz")
+plt.show()
diff --git a/AVOC.py b/AVOC.py
new file mode 100644
index 0000000..621b707
--- /dev/null
+++ b/AVOC.py
@@ -0,0 +1,149 @@
+from lxml import etree
+
+
+class GEN_Annotations:
+ def __init__(self, filename, imgpath):
+ self.root = etree.Element("annotation")
+
+ child1 = etree.SubElement(self.root, "folder")
+ child1.text = "ss"
+
+ child2 = etree.SubElement(self.root, "filename")
+ child2.text = filename
+
+ child3 = etree.SubElement(self.root,"path")
+ child3.text = imgpath
+
+ child4 = etree.SubElement(self.root, "source")
+
+ # child4 = etree.SubElement(child3, "annotation")
+ # child4.text = "PASCAL VOC2007"
+ child5 = etree.SubElement(child4, "database")
+ child5.text = "Unknown"
+ #
+ # child6 = etree.SubElement(child3, "image")
+ # child6.text = "flickr"
+ # child7 = etree.SubElement(child3, "flickrid")
+ # child7.text = "35435"
+
+ def set_size(self, witdh, height, channel):
+ size = etree.SubElement(self.root, "size")
+ widthn = etree.SubElement(size, "width")
+ widthn.text = str(witdh)
+ heightn = etree.SubElement(size, "height")
+ heightn.text = str(height)
+ channeln = etree.SubElement(size, "depth")
+ channeln.text = str(channel)
+
+ def set_segmented(self,seg=0):
+ segmented = etree.SubElement(self.root,"segmented")
+ segmented.text = str(seg)
+
+
+ def savefile(self, filename):
+ tree = etree.ElementTree(self.root)
+ tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')
+
+ def add_pic_attr(self, label, xmin, ymin, xmax, ymax):
+ object = etree.SubElement(self.root, "object")
+
+ namen = etree.SubElement(object, "name")
+ namen.text = label
+
+ pose = etree.SubElement(object,"pose")
+ pose.text = "Unspecified"
+
+ truncated = etree.SubElement(object,"truncated")
+ truncated.text = "0"
+
+ difficult = etree.SubElement(object,"difficult")
+ difficult.text = "0"
+
+ bndbox = etree.SubElement(object, "bndbox")
+ xminn = etree.SubElement(bndbox, "xmin")
+ xminn.text = str(xmin)
+ yminn = etree.SubElement(bndbox, "ymin")
+ yminn.text = str(ymin)
+ xmaxn = etree.SubElement(bndbox, "xmax")
+ xmaxn.text = str(xmax)
+ ymaxn = etree.SubElement(bndbox, "ymax")
+ ymaxn.text = str(ymax)
+
+
+import os
+import cv2
+
+
+def getFileList(dir, Filelist, ext=None):
+    """
+    Collect files from a folder and all of its subfolders.
+    Input dir: root folder to walk
+    Input ext: extension to keep (substring match on the last 3 chars)
+    Returns: list of matching file paths
+    """
+    newDir = dir
+    if os.path.isfile(dir):
+        if ext is None:
+            Filelist.append(dir)
+        else:
+            if ext in dir[-3:]:
+                Filelist.append(dir)
+
+    elif os.path.isdir(dir):
+        for s in os.listdir(dir):
+            newDir = os.path.join(dir, s)
+            getFileList(newDir, Filelist, ext)
+
+    return Filelist
+
+
+# org_img_folder = './org'
+
+# 检索文件
+# imglist = getFileList(org_img_folder, [], 'jpg')
+# print('本次执行检索到 ' + str(len(imglist)) + ' 张图像\n')
+
+# for imgpath in imglist:
+# imgname = os.path.splitext(os.path.basename(imgpath))[0]
+# img = cv2.imread(imgpath, cv2.IMREAD_COLOR)
+ # 对每幅图像执行相关操作
+
+if __name__ == '__main__':
+    org_img_folder = r'.\标注文件\mfcc\ss'
+
+    # collect the image files
+    imglist = getFileList(org_img_folder, [], 'png')
+    print('本次执行检索到 ' + str(len(imglist)) + ' 张图像\n')
+    #
+    # filename = imglist[0]
+    # name = filename.split('\\')
+    # # print(name)
+    # anno = GEN_Annotations(name[4],filename)
+    # anno.set_size(800, 550, 3)
+    # anno.set_segmented()
+    # for i in range(1):
+    #     xmin = i + 1
+    #     ymin = i + 1
+    #     xmax = i + 799
+    #     ymax = i + 549
+    #     anno.add_pic_attr("Snoring", xmin, ymin, xmax, ymax)
+    # filename_saved = filename.split('.')
+    # # print(filename_saved)
+    # anno.savefile('.'+filename_saved[1]+".xml")
+
+    for imagepath in imglist:
+        filename = imagepath
+        name = filename.split('\\')
+        # print(name)  -- name[4] is assumed to be the bare file name; TODO confirm path depth
+        anno = GEN_Annotations(name[4], filename)
+        anno.set_size(800, 550, 3)
+        for i in range(1):
+            xmin = i + 99
+            ymin = i + 64
+            xmax = i + 724
+            ymax = i + 493
+            anno.add_pic_attr("Snoring", xmin, ymin, xmax, ymax)
+        # filename_saved = filename.split('.')
+        filename_saved=name[4].split('.')
+        path=r'E:\语音处理\频谱\VOC\mfcc/ss/'
+        anno.savefile(path + filename_saved[0] + ".xml")
\ No newline at end of file
diff --git a/autoVOC-new.py b/autoVOC-new.py
new file mode 100644
index 0000000..e645c28
--- /dev/null
+++ b/autoVOC-new.py
@@ -0,0 +1,153 @@
+from lxml import etree
+
+
+class GEN_Annotations:
+ def __init__(self, filename):
+ self.root = etree.Element("annotation")
+
+ child1 = etree.SubElement(self.root, "folder")
+ child1.text = "folder"
+
+ child2 = etree.SubElement(self.root, "filename")
+ child2.text = filename
+
+ # child3 = etree.SubElement(self.root,"path")
+ # child3.text = imgpath
+
+ child4 = etree.SubElement(self.root, "source")
+
+ # child4 = etree.SubElement(child3, "annotation")
+ # child4.text = "PASCAL VOC2007"
+ child5 = etree.SubElement(child4, "database")
+ child5.text = "Unknown"
+ #
+ # child6 = etree.SubElement(child3, "image")
+ # child6.text = "flickr"
+ # child7 = etree.SubElement(child3, "flickrid")
+ # child7.text = "35435"
+
+ def set_size(self, witdh, height, channel):
+ size = etree.SubElement(self.root, "size")
+ widthn = etree.SubElement(size, "width")
+ widthn.text = str(witdh)
+ heightn = etree.SubElement(size, "height")
+ heightn.text = str(height)
+ channeln = etree.SubElement(size, "depth")
+ channeln.text = str(channel)
+
+ def set_segmented(self,seg=0):
+ segmented = etree.SubElement(self.root,"segmented")
+ segmented.text = str(seg)
+
+
+ def savefile(self, filename):
+ tree = etree.ElementTree(self.root)
+ tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')
+
+ def add_pic_attr(self, label, xmin, ymin, xmax, ymax):
+ object = etree.SubElement(self.root, "object")
+
+ namen = etree.SubElement(object, "name")
+ namen.text = label
+
+ pose = etree.SubElement(object,"pose")
+ pose.text = "Unspecified"
+
+ truncated = etree.SubElement(object,"truncated")
+ truncated.text = "0"
+
+ difficult = etree.SubElement(object,"difficult")
+ difficult.text = "0"
+
+ bndbox = etree.SubElement(object, "bndbox")
+ xminn = etree.SubElement(bndbox, "xmin")
+ xminn.text = str(xmin)
+ yminn = etree.SubElement(bndbox, "ymin")
+ yminn.text = str(ymin)
+ xmaxn = etree.SubElement(bndbox, "xmax")
+ xmaxn.text = str(xmax)
+ ymaxn = etree.SubElement(bndbox, "ymax")
+ ymaxn.text = str(ymax)
+
+
+import os
+import cv2
+
+
+def getFileList(dir, Filelist, ext=None):
+    """
+    Collect files from a folder and all of its subfolders.
+    Input dir: root folder to walk
+    Input ext: extension to keep (substring match on the last 3 chars)
+    Returns: list of matching file paths
+    """
+    newDir = dir
+    if os.path.isfile(dir):
+        if ext is None:
+            Filelist.append(dir)
+        else:
+            if ext in dir[-3:]:
+                Filelist.append(dir)
+
+    elif os.path.isdir(dir):
+        for s in os.listdir(dir):
+            newDir = os.path.join(dir, s)
+            getFileList(newDir, Filelist, ext)
+
+    return Filelist
+
+
+# org_img_folder = './org'
+
+# 检索文件
+# imglist = getFileList(org_img_folder, [], 'jpg')
+# print('本次执行检索到 ' + str(len(imglist)) + ' 张图像\n')
+
+# for imgpath in imglist:
+# imgname = os.path.splitext(os.path.basename(imgpath))[0]
+# img = cv2.imread(imgpath, cv2.IMREAD_COLOR)
+ # 对每幅图像执行相关操作
+
+if __name__ == '__main__':
+    org_img_folder = r'.\标注文件\wave\ss'
+
+    # collect the image files
+    imglist = getFileList(org_img_folder, [], 'jpg')
+    print('本次执行检索到 ' + str(len(imglist)) + ' 张图像\n')
+    #
+    # filename = imglist[0]
+    # name = filename.split('\\')
+    # # print(name)
+    # anno = GEN_Annotations(name[4],filename)
+    # anno.set_size(800, 550, 3)
+    # anno.set_segmented()
+    # for i in range(1):
+    #     xmin = i + 1
+    #     ymin = i + 1
+    #     xmax = i + 799
+    #     ymax = i + 549
+    #     anno.add_pic_attr("Snoring", xmin, ymin, xmax, ymax)
+    # filename_saved = filename.split('.')
+    # # print(filename_saved)
+    # anno.savefile('.'+filename_saved[1]+".xml")
+
+    for imagepath in imglist:
+        filename = imagepath
+        name = filename.split('\\')
+        # print(name)  -- name[4] is assumed to be the bare file name; TODO confirm path depth
+        anno = GEN_Annotations(name[4])
+        anno.set_size(800, 550, 3)
+        anno.set_segmented()
+        for i in range(1):
+            xmin = 105
+            ymin = 72
+            xmax = 718
+            ymax = 486
+            # xmin = 99
+            # ymin = 64
+            # xmax = 724
+            # ymax = 493
+            anno.add_pic_attr("Snoring", xmin, ymin, xmax, ymax)
+        filename_saved = name[4].split('.')
+        path = r'E:\语音处理\频谱\anno\wave/'
+        anno.savefile(path + filename_saved[0] + ".xml")
\ No newline at end of file
diff --git a/autoVOC.py b/autoVOC.py
new file mode 100644
index 0000000..4c0039d
--- /dev/null
+++ b/autoVOC.py
@@ -0,0 +1,148 @@
+from lxml import etree
+
+
+class GEN_Annotations:
+ def __init__(self, filename, imgpath):
+ self.root = etree.Element("annotation")
+
+ child1 = etree.SubElement(self.root, "folder")
+ child1.text = "ss"
+
+ child2 = etree.SubElement(self.root, "filename")
+ child2.text = filename
+
+ child3 = etree.SubElement(self.root,"path")
+ child3.text = imgpath
+
+ child4 = etree.SubElement(self.root, "source")
+
+ # child4 = etree.SubElement(child3, "annotation")
+ # child4.text = "PASCAL VOC2007"
+ child5 = etree.SubElement(child4, "database")
+ child5.text = "Unknown"
+ #
+ # child6 = etree.SubElement(child3, "image")
+ # child6.text = "flickr"
+ # child7 = etree.SubElement(child3, "flickrid")
+ # child7.text = "35435"
+
+ def set_size(self, witdh, height, channel):
+ size = etree.SubElement(self.root, "size")
+ widthn = etree.SubElement(size, "width")
+ widthn.text = str(witdh)
+ heightn = etree.SubElement(size, "height")
+ heightn.text = str(height)
+ channeln = etree.SubElement(size, "depth")
+ channeln.text = str(channel)
+
+ def set_segmented(self,seg=0):
+ segmented = etree.SubElement(self.root,"segmented")
+ segmented.text = str(seg)
+
+
+ def savefile(self, filename):
+ tree = etree.ElementTree(self.root)
+ tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')
+
+ def add_pic_attr(self, label, xmin, ymin, xmax, ymax):
+ object = etree.SubElement(self.root, "object")
+
+ namen = etree.SubElement(object, "name")
+ namen.text = label
+
+ pose = etree.SubElement(object,"pose")
+ pose.text = "Unspecified"
+
+ truncated = etree.SubElement(object,"truncated")
+ truncated.text = "0"
+
+ difficult = etree.SubElement(object,"difficult")
+ difficult.text = "0"
+
+ bndbox = etree.SubElement(object, "bndbox")
+ xminn = etree.SubElement(bndbox, "xmin")
+ xminn.text = str(xmin)
+ yminn = etree.SubElement(bndbox, "ymin")
+ yminn.text = str(ymin)
+ xmaxn = etree.SubElement(bndbox, "xmax")
+ xmaxn.text = str(xmax)
+ ymaxn = etree.SubElement(bndbox, "ymax")
+ ymaxn.text = str(ymax)
+
+
+import os
+import cv2
+
+
+def getFileList(dir, Filelist, ext=None):
+    """
+    Collect files from a folder and all of its subfolders.
+    Input dir: root folder to walk
+    Input ext: extension to keep (substring match on the last 3 chars)
+    Returns: list of matching file paths
+    """
+    newDir = dir
+    if os.path.isfile(dir):
+        if ext is None:
+            Filelist.append(dir)
+        else:
+            if ext in dir[-3:]:
+                Filelist.append(dir)
+
+    elif os.path.isdir(dir):
+        for s in os.listdir(dir):
+            newDir = os.path.join(dir, s)
+            getFileList(newDir, Filelist, ext)
+
+    return Filelist
+
+
+# org_img_folder = './org'
+
+# 检索文件
+# imglist = getFileList(org_img_folder, [], 'jpg')
+# print('本次执行检索到 ' + str(len(imglist)) + ' 张图像\n')
+
+# for imgpath in imglist:
+# imgname = os.path.splitext(os.path.basename(imgpath))[0]
+# img = cv2.imread(imgpath, cv2.IMREAD_COLOR)
+ # 对每幅图像执行相关操作
+
+if __name__ == '__main__':
+    org_img_folder = r'D:\snoring-dataset\Snoring Dataset\音频数据\标注文件\mfcc\no/'
+    files=os.listdir(org_img_folder)
+    # collect the image files
+    imglist = getFileList(org_img_folder, [], 'png')
+    print('本次执行检索到 ' + str(len(imglist)) + ' 张图像\n')
+    #
+    # filename = imglist[0]
+    # name = filename.split('\\')
+    # # print(name)
+    # anno = GEN_Annotations(name[4],filename)
+    # anno.set_size(800, 550, 3)
+    # anno.set_segmented()
+    # for i in range(1):
+    #     xmin = i + 1
+    #     ymin = i + 1
+    #     xmax = i + 799
+    #     ymax = i + 549
+    #     anno.add_pic_attr("Snoring", xmin, ymin, xmax, ymax)
+    # filename_saved = filename.split('.')
+    # # print(filename_saved)
+    # anno.savefile('.'+filename_saved[1]+".xml")
+
+    # fix: iterate the listed file names -- the original enumerated the characters of the path string
+    for i, img in enumerate(files):
+        filename = os.path.splitext(img)[0]
+        filetype = os.path.splitext(img)[1]
+        if filetype != '.png':
+            continue
+        anno = GEN_Annotations(img, os.path.join(org_img_folder, img))
+        anno.set_size(800, 550, 3)
+        for j in range(1):
+            xmin = j + 99
+            ymin = j + 64
+            xmax = j + 724
+            ymax = j + 493
+            anno.add_pic_attr("No Snoring", xmin, ymin, xmax, ymax)
+        anno.savefile(os.path.join(org_img_folder, filename + ".xml"))
\ No newline at end of file
diff --git a/cut.py b/cut.py
new file mode 100644
index 0000000..81f68cb
--- /dev/null
+++ b/cut.py
@@ -0,0 +1,17 @@
+import os
+from PIL import Image
+import numpy as np
+
+rootimgs = 'D:\paper\\3low_light_image\compare_lowlighr_enchace\enhancement_image\MBLLEN\\'
+targetroot = 'D:\paper\\3low_light_image\compare_lowlighr_enchace\enhancement_image\\'
+savdir = 'D:\paper\\3low_light_image\compare_lowlighr_enchace\enhancement_image\\'
+file_imgs = os.listdir(rootimgs)
+
+for file_img in file_imgs:
+    imgpath = rootimgs + file_img
+    targetimg = targetroot + file_img
+    image = Image.open(imgpath)  # open the image with PIL's Image.open
+    image_arr = np.array(image)  # convert to a numpy array
+    image_tar = image_arr[:, int(image_arr.shape[1] / 3):int(2 * image_arr.shape[1] / 3), :]
+    im = Image.fromarray(image_tar)
+    im.save(targetimg)
\ No newline at end of file
diff --git a/data-number.py b/data-number.py
new file mode 100644
index 0000000..41736cd
--- /dev/null
+++ b/data-number.py
@@ -0,0 +1,66 @@
+import matplotlib.pyplot as plt
+import os
+from urllib import request, parse
+import json
+# Youdao translate: Chinese -> English (language pair auto-detected)
+def fy(i):
+    req_url = 'http://fanyi.youdao.com/translate'  # request endpoint
+    # build the form data to submit
+    Form_Date = {}
+    Form_Date['i'] = i
+    Form_Date['doctype'] = 'json'
+    Form_Date['form'] = 'AUTO'
+    Form_Date['to'] = 'AUTO'
+    Form_Date['smartresult'] = 'dict'
+    Form_Date['client'] = 'fanyideskweb'
+    Form_Date['salt'] = '1526995097962'
+    Form_Date['sign'] = '8e4c4765b52229e1f3ad2e633af89c76'
+    Form_Date['version'] = '2.1'
+    Form_Date['keyform'] = 'fanyi.web'
+    Form_Date['action'] = 'FY_BY_REALTIME'
+    Form_Date['typoResult'] = 'false'
+
+    data = parse.urlencode(Form_Date).encode('utf-8')  # encode the form
+    response = request.urlopen(req_url, data)  # submit the request and get the response
+    html = response.read().decode('utf-8')  # read the server's reply
+    # print(html)
+    # the reply body is JSON
+    translate_results = json.loads(html)  # load as JSON
+    translate_results = translate_results['translateResult'][0][0]['tgt']  # extract the translation
+    # print(translate_results)  # show the result
+    return translate_results;  # return the translation
+#
+#
+# res = fy('this is a dog')
+# print(res) # 这是一只狗
+
+
+plt.style.use("seaborn")
+no_snore_path='D:/snoring-dataset/no-snore/'
+no_snore_path_dir=os.listdir(no_snore_path)
+no_snore_num=0
+count=0
+no_snore_typrname=[]
+for i,filename in enumerate(no_snore_path_dir):
+ oldname=filename
+ print(oldname)
+ newname = filename[6:]
+ newname=fy(newname)
+ print(newname)
+ Oldname=os.path.join(no_snore_path,oldname)
+ Newname=os.path.join(no_snore_path,newname)
+ os.rename(Oldname,Newname)
+ # no_snore_typrname.append(newname)
+ count+=1
+no_snore_num=count
+print(no_snore_num)
+# print(no_snore_typrname)
+# print(no_snore_num)
+# newnames=[]
+# for i,file in enumerate(no_snore_typrname):
+# oldname=file
+# newname=file[6:]
+# newname=fy(newname)
+# newnames.append(newname)
+# os.rename(oldname,newname)
+# print(newnames)
\ No newline at end of file
diff --git a/fft.py b/fft.py
new file mode 100644
index 0000000..6d84a85
--- /dev/null
+++ b/fft.py
@@ -0,0 +1,20 @@
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
+import pywt
+# read the audio file
+filepath = 'D:\snoring-dataset\Snoring Dataset/'
+filename = filepath + '000000.wav'
+x, sr = librosa.load(filename, sr=None)  # x: audio time series (1-D array); sr: sample rate
+
+# apply the STFT and draw the spectrogram
+
+X = librosa.stft(x)
+Xdb = librosa.amplitude_to_db(abs(X))  # X is a 2-D array
+
+plt.figure(figsize=(5, 5))
+librosa.display.specshow(Xdb, sr=sr, x_axis='time', y_axis='log')
+plt.colorbar()
+plt.title('STFT transform processing audio signal')
+plt.show()
+
diff --git a/lsxbbh.py b/lsxbbh.py
new file mode 100644
index 0000000..2b360db
--- /dev/null
+++ b/lsxbbh.py
@@ -0,0 +1,32 @@
+import pywt
+import matplotlib.pyplot as plt
+import numpy as np
+# wavelet demo: build a piecewise multi-tone signal and show its CWT scalogram
+sampling_rate = 1024
+t = np.arange(0, 1.0, 1.0 / sampling_rate)
+f1 = 100
+f2 = 200
+f3 = 300
+f4 = 400
+data = np.piecewise(t, [t < 1, t < 0.8, t < 0.5, t < 0.3],
+ [lambda t: 400*np.sin(2 * np.pi * f4 * t),
+ lambda t: 300*np.sin(2 * np.pi * f3 * t),
+ lambda t: 200*np.sin(2 * np.pi * f2 * t),
+ lambda t: 100*np.sin(2 * np.pi * f1 * t)])
+wavename = 'cgau8'
+totalscal = 256
+fc = pywt.central_frequency(wavename)
+cparam = 2 * fc * totalscal
+scales = cparam / np.arange(totalscal, 1, -1)
+[cwtmatr, frequencies] = pywt.cwt(data, scales, wavename, 1.0 / sampling_rate)
+plt.figure(figsize=(8, 4))
+plt.subplot(211)
+plt.plot(t, data)
+plt.xlabel("t(s)")
+plt.title('shipinpu', fontsize=20)
+plt.subplot(212)
+plt.contourf(t, frequencies, abs(cwtmatr))
+plt.ylabel(u"prinv(Hz)")
+plt.xlabel(u"t(s)")
+plt.subplots_adjust(hspace=0.4)
+plt.show()
\ No newline at end of file
diff --git a/lxml+plot(1).py b/lxml+plot(1).py
new file mode 100644
index 0000000..f8592c4
--- /dev/null
+++ b/lxml+plot(1).py
@@ -0,0 +1,50 @@
+from lxml import etree
+with open(r'C:\Users\c9347\Desktop\hehehe\temp.xml', 'r', encoding='utf-8') as f:
+    xml_text = f.read()  # renamed from 'str' so the builtin is not shadowed
+# print(xml_text)
+xml=etree.fromstring(xml_text)  # NOTE(review): overwritten by parse() below
+# xml=etree.XML(xml_text)
+# xml=etree.HTML(xml_text)
+xml=etree.parse(r'C:\Users\c9347\Desktop\hehehe\temp.xml')
+
+print(etree.tostring(xml).decode('utf-8'))
+name=xml.xpath('/annotation/object/name/text()')
+print(name)
+
+# import matplotlib.pyplot as plt
+# # # 一个figure(画布)上,可以有多个区域axes(坐标系),
+# # # 我们在每个坐标系上绘图,也就是说每个axes(坐标系)中,都有axis(坐标轴)。
+# # # 如果绘制一个简单的小图形,我们可以不设置figure对象,使用默认创建的figure对象,
+# # # 当然我们也可以显示创建figure对象。如果一张figure画布上,需要绘制多个图形。
+# # # 那么就必须显示的创建figure对象,然后得到每个位置上的axes对象,进行对应位置上的图形绘制。
+# # # 定义fig
+# # fig = plt.figure()
+# # # 建立子图
+# # ax = fig.subplots(2,2) # 2*2
+# # fig, ax = plt.subplots(2,2)
+# # # 第一个图为
+# # ax[0,0].plot([1,2,5], [3,4,8],label='a')
+# # # 第二个图为
+# # ax[0,1].plot([1,2], [3,4])
+# # # 第三个图为
+# # ax[1,1].plot([1,2], [3,4])
+# # # 第四个图为
+# # ax[1,0].plot([1,2], [3,4])
+# # x1 = [0, 1, 2, 3]
+# # y1 = [3, 7, 5, 9]
+# # x2 = [0, 1, 2, 3]
+# # y2 = [6, 2, 13, 10]
+# #
+# # ax[0,0].plot(x1, y1,label='b')
+# # ax[0,0].plot(x2, y2,label='c')
+# # ax[0,0].xticks([0,2,4,6,8])
+# # plt.show()
+# counts=[1,5,4,7,5]
+# labels=['a','b','c','d','e']
+# fig,ax=plt.subplots()
+# # ax.pie(counts,labels=labels,colors=['red','blue','red','blue','red'])
+# # plt.show()
+# ax.barh(labels,counts)
+# for index,item in enumerate(counts):
+# ax.text(item+1,index,str(item))
+# plt.show()
\ No newline at end of file
diff --git a/mel.py b/mel.py
new file mode 100644
index 0000000..4eb861f
--- /dev/null
+++ b/mel.py
@@ -0,0 +1,197 @@
+import IPython
+import cv2
+import IPython.display
+
+import librosa
+import librosa.display
+
+from fastai.vision import *
+
+import os
+
+DATA = 'D:\CloudMusic/'
+CSV_TRN_CURATED = DATA + 'train_curated.csv'  # training csv: file name, label (fix: main() needs this defined)
+TRN_CURATED = DATA + 'train_curated'  # folder holding the training audio files
+
+# Mel-spectrogram Dataset
+PREPROCESSED = os.path.join(DATA)  # where generated data is saved
+MELS_TRN_CURATED = os.path.join(PREPROCESSED, 'mels_train_curated.pkl')  # output pickle of spectrogram images
+
+
+def read_audio(conf, pathname, trim_long_data):
+    """
+    Load a wav file with librosa at conf.sampling_rate (44100) and force it to conf.samples long.
+    :param conf: config object providing sampling_rate, samples and padmode
+    :param pathname: path of the audio file
+    :param trim_long_data: if True, truncate clips longer than conf.samples
+    :return: 1-D float array of audio samples
+    """
+    y, sr = librosa.load(pathname, sr=conf.sampling_rate)  # load the audio as a float time series
+    # trim silence
+    if 0 < len(y):  # workaround: 0 length causes error
+        y, _ = librosa.effects.trim(y)  # trim, top_db=default(60)
+    # make it unified length to conf.samples
+    if len(y) > conf.samples:  # long enough 88200
+        if trim_long_data:
+            y = y[0:0 + conf.samples]
+    else:  # pad blank
+        padding = conf.samples - len(y)  # pad at both ends when the clip is too short
+        offset = padding // 2
+        y = np.pad(y, (offset, conf.samples - len(y) - offset), conf.padmode)
+    return y
+
+
+def audio_to_melspectrogram(conf, audio):
+ """
+ 计算一个梅尔频谱系数图
+ :param conf:
+ :param audio:
+ :return:
+ """
+ spectrogram = librosa.feature.melspectrogram(audio,
+ sr=conf.sampling_rate,
+ n_mels=conf.n_mels,
+ hop_length=conf.hop_length,
+ n_fft=conf.n_fft,
+ fmin=conf.fmin,
+ fmax=conf.fmax)
+ spectrogram = librosa.power_to_db(spectrogram) # 转化频谱系数单位
+ spectrogram = spectrogram.astype(np.float32)
+ return spectrogram
+
+
+def show_melspectrogram(conf, mels, title='Log-frequency power spectrogram'):
+ """
+
+ :param conf:
+ :param mels:
+ :param title:
+ :return:
+ """
+ librosa.display.specshow(mels, x_axis='time', y_axis='mel',
+ sr=conf.sampling_rate, hop_length=conf.hop_length,
+ fmin=conf.fmin, fmax=conf.fmax)
+ plt.colorbar(format='%+2.0f dB')
+ plt.title(title)
+ plt.show()
+
+
+def read_as_melspectrogram(conf, pathname, trim_long_data, debug_display=False):
+ """
+ :param conf:
+ :param pathname:
+ :param trim_long_data:
+ :param debug_display:
+ :return:
+ """
+ x = read_audio(conf, pathname, trim_long_data)
+ mels = audio_to_melspectrogram(conf, x)
+ if debug_display:
+ IPython.display.display(IPython.display.Audio(x, rate=conf.sampling_rate))
+ show_melspectrogram(conf, mels)
+ return mels
+
+
+def mono_to_color(X, mean=None, std=None, norm_max=None, norm_min=None, eps=1e-6):
+ """
+
+ :param X:
+ :param mean:
+ :param std:
+ :param norm_max:
+ :param norm_min:
+ :param eps:
+ :return:
+ """
+ # Stack X as [X,X,X]
+ X = np.stack([X, X, X], axis=-1)
+
+ # Standardize
+ mean = mean or X.mean()
+ X = X - mean
+ std = std or X.std()
+ Xstd = X / (std + eps)
+ _min, _max = Xstd.min(), Xstd.max()
+ norm_max = norm_max or _max
+ norm_min = norm_min or _min
+ if (_max - _min) > eps:
+ # Normalize to [0, 255]
+ V = Xstd
+ V[V < norm_min] = norm_min
+ V[V > norm_max] = norm_max
+ V = 255 * (V - norm_min) / (norm_max - norm_min)
+ V = V.astype(np.uint8)
+ else:
+ # Just zero
+ V = np.zeros_like(Xstd, dtype=np.uint8)
+ return V
+
+
+def convert_wav_to_image(df, source):
+    """
+    Convert WAV files to spectrogram images; returns the list of images.
+    :param df: dataframe with an 'fname' column of wav file names
+    :param source: folder containing the wav files
+    :return: list of 3-channel uint8 images
+    """
+    X = []
+    for i, row in df.iterrows():
+        wav_path = os.path.join(source, str(row.fname))  # path of the wav file
+        print(wav_path)
+        x = read_as_melspectrogram(conf, wav_path, trim_long_data=False)  # wav -> log-mel array
+        x_color = mono_to_color(x)  # expand to a 3-channel image
+        X.append(x_color)
+    return X
+
+
+def save_as_pkl_binary(obj, filename):
+ """Save object as pickle binary file.
+ Thanks to https://stackoverflow.com/questions/19201290/how-to-save-a-dictionary-to-a-file/32216025
+ """
+ with open(filename, 'wb') as f:
+ pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
+
+
+def convert_dataset(df, source_folder, filename):
+    """
+    Convert the WAV files listed in df to images and pickle them to `filename`.
+    :param df: dataframe with an 'fname' column
+    :param source_folder: folder containing the wav files
+    :param filename: output .pkl path
+    :return: the list of images
+    """
+    X = convert_wav_to_image(df, source=source_folder)
+    save_as_pkl_binary(X, filename)
+    print(f'Created {filename}')
+    return X
+
+
+class conf:
+ sampling_rate = 44100
+ duration = 2 # sec
+ hop_length = 347 * duration # to make time steps 128
+ fmin = 20
+ fmax = sampling_rate // 2
+ n_mels = 128
+ n_fft = n_mels * 20
+ padmode = 'constant'
+ samples = sampling_rate * duration
+
+
+def get_default_conf():
+ return conf
+
+
+def main():
+ trn_curated_df = pd.read_csv(CSV_TRN_CURATED)
+
+ # 获取配置参数
+ conf = get_default_conf()
+
+ # 转化数据集 128xN (N/128)*2=时长。
+ convert_dataset(trn_curated_df, TRN_CURATED, MELS_TRN_CURATED);
+ # convert_dataset(test_df, TEST, MELS_TEST);
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/melm.py b/melm.py
new file mode 100644
index 0000000..f3e19ad
--- /dev/null
+++ b/melm.py
@@ -0,0 +1,52 @@
+import matplotlib.pyplot as plt
+import librosa
+import librosa.display
+import numpy as np
+import sys
+
+
+# read the wav audio file
+audio_path = r"D:\CloudMusic\no/000000.wav"
+y, sr = librosa.load(audio_path, sr=None, mono=True)
+"""
+:param
+    path     audio file path
+    sr       sample rate (default 22050, resampling is supported)
+    mono     True -> single channel, otherwise two channels
+    offset   start time of the read
+    duration length of audio to read
+
+:returns
+    y : audio signal values, an ndarray
+    sr : sample rate
+"""
+###############################################################################
+
+################################################################################
+# 03 get the mel spectrogram with librosa
+n_mels = 64
+n_frames = 5
+n_fft = 1024
+hop_length = 512
+power = 2.0
+
+mel_spectrogram = librosa.feature.melspectrogram(y=y,
+                                                 sr=sr,
+                                                 n_fft=n_fft,
+                                                 hop_length=hop_length,
+                                                 n_mels=n_mels,
+                                                 power=power)
+
+# librosa.display.specshow(librosa.power_to_db(mel_spectrogram, ref=np.max),
+#                          y_axis='mel', fmax=8000, x_axis='time')
+# plt.colorbar(format='%+2.0f dB')
+##################################################################################
+
+# 04 convert mel to log-mel (NOTE(review): power_to_db below re-applies dB to log data -- verify intent)
+log_mel_spectrogram = 20.0 / power * np.log10(np.maximum(mel_spectrogram, sys.float_info.epsilon))
+librosa.display.specshow(librosa.power_to_db(log_mel_spectrogram, ref=np.max),
+                         y_axis='mel', fmax=8000, x_axis='time')
+# plt.colorbar(format='%+2.0f dB')
+##################################################################################
+
+plt.show()
diff --git a/mfcc.py b/mfcc.py
new file mode 100644
index 0000000..462b563
--- /dev/null
+++ b/mfcc.py
@@ -0,0 +1,21 @@
+import matplotlib.pyplot as plt
+import librosa
+import librosa.display
+plt.style.use('seaborn')
+
+y, sr = librosa.load('D:\CloudMusic/no/000000.wav', sr=16000)
+# extract the mel spectrogram feature (fix: keyword args -- positional y/sr were removed in librosa 0.10)
+melspec = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=512, n_mels=128)
+logmelspec = librosa.power_to_db(melspec)  # convert to a logarithmic (dB) scale
+# plot the mel spectrogram
+plt.figure()
+librosa.display.specshow(logmelspec, sr=sr, x_axis='time', y_axis='mel')
+plt.colorbar(format='%+2.0f dB')  # colour bar on the right
+plt.title('Beat waveform')
+plt.show()
+
+# mfccs = librosa.feature.mfcc(y,)
+# plt.figure()
+# librosa.display.specshow(mfccs,)
+# plt.title('mfcc')
+# plt.show()
diff --git a/multi classification.py b/multi classification.py
new file mode 100644
index 0000000..ec77df8
--- /dev/null
+++ b/multi classification.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+import os, random, shutil
+def moveFile(fileDir):
+    pathDir = os.listdir(fileDir)  # original image file names
+    filenumber = len(pathDir)
+    picknumber = int(filenumber * ratio)  # take `ratio` of the files from the folder
+    sample = random.sample(pathDir, picknumber)  # randomly pick `picknumber` sample images
+    for name in sample:
+        shutil.move(os.path.join(fileDir, name), os.path.join(tarDir, name))
+    return
+if __name__ == '__main__':
+    ori_path = 'D:/snoring-dataset/Snoring Dataset/309-no-snoring-1'  # original train folder
+    split_Dir = 'D:/snoring-dataset/Snoring Dataset/test-309'  # destination folder
+    ratio = 0.1  # fraction of files to move
+    for firstPath in os.listdir(ori_path):
+        fileDir = os.path.join(ori_path, firstPath)  # source subfolder path
+        tarDir = os.path.join(split_Dir, firstPath)  # matching subfolder under the destination
+        if not os.path.exists(tarDir):  # create the destination subfolder if missing
+            os.makedirs(tarDir)
+        moveFile(fileDir)  # split each subclass in turn
\ No newline at end of file
diff --git a/png-jpg.py b/png-jpg.py
new file mode 100644
index 0000000..349615e
--- /dev/null
+++ b/png-jpg.py
@@ -0,0 +1,16 @@
+import os
+# folder of png files (NOTE(review): this renames the extension only, it does not re-encode the images)
+png_path=r'E:\语音处理\频谱\标注文件\stft\no'
+# jpg_path=r'E:\语音处理\频谱'
+files=os.listdir(png_path)
+k=0
+for i,file in enumerate(files):
+    filename=os.path.splitext(file)[0]
+    filetype=os.path.splitext(file)[1]
+    if filetype=='.png':
+        old_name=os.path.join(png_path,file)
+        new_name=os.path.join(png_path,filename+'.jpg')
+        os.rename(old_name,new_name)
+        # print(old_name,new_name)
+        k+=1
+print(k)
\ No newline at end of file
diff --git a/renamefile.py b/renamefile.py
new file mode 100644
index 0000000..0c1a412
--- /dev/null
+++ b/renamefile.py
@@ -0,0 +1,17 @@
+import os
+path=r'D:\snoring-dataset\40test\402 - Mouse click'
+filedir=os.listdir(path)
+count=0
+for i,file in enumerate(filedir):
+    # split the file name and its extension
+    filename=os.path.splitext(file)[0]
+    filetype=os.path.splitext(file)[1]
+    # check the file type (NOTE(review): renames .ogg to .wav without converting the audio -- verify)
+    if filetype=='.ogg':
+        if count%10==0:
+            print(count)
+        oldname=os.path.join(path,file)
+        newname=os.path.join(path,str(count+3000).zfill(6)+'.wav')
+        os.rename(oldname,newname)
+        count+=1
+print(count)
diff --git a/resize.py b/resize.py
new file mode 100644
index 0000000..67528c2
--- /dev/null
+++ b/resize.py
@@ -0,0 +1,27 @@
+# encoding:utf-8
+
+# Resize an image; mainly used to shrink images that exceed a size limit.
+
+import cv2
+
+if __name__ == '__main__':
+    img = cv2.imread('D:\snoring-dataset\Snoring Dataset\音频数据\标注文件\mfcc\ss/000000.png')
+    cv2.imshow('resize before', img)
+    # directly specify the target size
+    img = cv2.resize(img, (416, 416))
+
+    # shrink proportionally, e.g. by a factor of 2
+    # original height
+    # height = img.shape[0]
+    # # original width
+    # width = img.shape[1]
+    # # tuple argument is (width, height)
+    # img = cv2.resize(img, (int(width / 2), int(height / 2)))
+
+    cv2.imshow('resize after', img)
+
+    # write to a new file
+    cv2.imwrite('./2.jpg', img)
+    # block so the windows stay open
+    cv2.waitKey()
+
diff --git a/t.py b/t.py
new file mode 100644
index 0000000..51aadb5
--- /dev/null
+++ b/t.py
@@ -0,0 +1,13 @@
+# waveform plot
+# wave=thinkdsp.read_wave("D:\CloudMusic\ss/000000.wav")
+# wave.plot()
+# plt.savefig('D:\CloudMusic\ss/test1')
+# plt.show()
+# spectrum plot
+import thinkdsp
+from 频谱 import thinkplot
+
+wave= thinkdsp.read_wave("D:\CloudMusic\ss/000000.wav")
+spectrum=wave.make_spectrum()
+spectrum.plot()
+thinkplot.show()
diff --git a/thinkdsp.py b/thinkdsp.py
new file mode 100644
index 0000000..7d9d0a5
--- /dev/null
+++ b/thinkdsp.py
@@ -0,0 +1,1631 @@
+from __future__ import print_function, division
+
+import copy
+import math
+
+import numpy as np
+import random
+import scipy
+import scipy.stats
+import scipy.fftpack
+import subprocess
+from 频谱 import thinkplot
+import warnings
+
+from wave import open as open_wave
+
+import matplotlib.pyplot as pyplot
+
+try:
+ from IPython.display import Audio
+except:
+ warnings.warn(
+ "Can't import Audio from IPython.display; " "Wave.make_audio() will not work."
+ )
+
+PI2 = math.pi * 2
+
+
+def random_seed(x):
+ """Initialize the random and np.random generators.
+ x: int seed
+ """
+ random.seed(x)
+ np.random.seed(x)
+
+
+class UnimplementedMethodException(Exception):
+ """Exception if someone calls a method that should be overridden."""
+
+
+class WavFileWriter:
+ """Writes wav files."""
+
+ def __init__(self, filename="sound.wav", framerate=11025):
+ """Opens the file and sets parameters.
+ filename: string
+ framerate: samples per second
+ """
+ self.filename = filename
+ self.framerate = framerate
+ self.nchannels = 1
+ self.sampwidth = 2
+ self.bits = self.sampwidth * 8
+ self.bound = 2 ** (self.bits - 1) - 1
+
+ self.fmt = "h"
+ self.dtype = np.int16
+
+ self.fp = open_wave(self.filename, "w")
+ self.fp.setnchannels(self.nchannels)
+ self.fp.setsampwidth(self.sampwidth)
+ self.fp.setframerate(self.framerate)
+
+ def write(self, wave):
+ """Writes a wave.
+ wave: Wave
+ """
+ zs = wave.quantize(self.bound, self.dtype)
+ self.fp.writeframes(zs.tostring())
+
+ def close(self, duration=0):
+ """Closes the file.
+ duration: how many seconds of silence to append
+ """
+ if duration:
+ self.write(rest(duration))
+
+ self.fp.close()
+
+
+def read_wave(filename="sound.wav"):
+ """Reads a wave file.
+ filename: string
+ returns: Wave
+ """
+ fp = open_wave(filename, "r")
+
+ nchannels = fp.getnchannels()
+ nframes = fp.getnframes()
+ sampwidth = fp.getsampwidth()
+ framerate = fp.getframerate()
+
+ z_str = fp.readframes(nframes)
+
+ fp.close()
+
+ dtype_map = {1: np.int8, 2: np.int16, 3: "special", 4: np.int32}
+ if sampwidth not in dtype_map:
+ raise ValueError("sampwidth %d unknown" % sampwidth)
+
+ if sampwidth == 3:
+ xs = np.fromstring(z_str, dtype=np.int8).astype(np.int32)
+ ys = (xs[2::3] * 256 + xs[1::3]) * 256 + xs[0::3]
+ else:
+ ys = np.fromstring(z_str, dtype=dtype_map[sampwidth])
+
+ # if it's in stereo, just pull out the first channel
+ if nchannels == 2:
+ ys = ys[::2]
+
+ # ts = np.arange(len(ys)) / framerate
+ wave = Wave(ys, framerate=framerate)
+ wave.normalize()
+ return wave
+
+
+def play_wave(filename="sound.wav", player="aplay"):
+    """Plays a wave file by invoking an external player and waiting for it to finish.
+    filename: string
+    player: string name of executable that plays wav files
+    """
+    cmd = "%s %s" % (player, filename)
+    popen = subprocess.Popen(cmd, shell=True)  # NOTE(review): shell=True with an unquoted filename breaks (or injects) on paths with spaces/metacharacters
+    popen.communicate()
+
+
+def find_index(x, xs):
+    """Find the index corresponding to a given value in an array (assumes xs is evenly spaced: position is linearly interpolated between xs[0] and xs[-1], then rounded)."""
+    n = len(xs)
+    start = xs[0]
+    end = xs[-1]
+    i = round((n - 1) * (x - start) / (end - start))
+    return int(i)
+
+
+class _SpectrumParent:
+ """Contains code common to Spectrum and DCT.
+ """
+
+ def __init__(self, hs, fs, framerate, full=False):
+ """Initializes a spectrum.
+ hs: array of amplitudes (real or complex)
+ fs: array of frequencies
+ framerate: frames per second
+ full: boolean to indicate full or real FFT
+ """
+ self.hs = np.asanyarray(hs)
+ self.fs = np.asanyarray(fs)
+ self.framerate = framerate
+ self.full = full
+
+ @property
+ def max_freq(self):
+ """Returns the Nyquist frequency for this spectrum."""
+ return self.framerate / 2
+
+ @property
+ def amps(self):
+ """Returns a sequence of amplitudes (read-only property)."""
+ return np.absolute(self.hs)
+
+ @property
+ def power(self):
+ """Returns a sequence of powers (read-only property)."""
+ return self.amps ** 2
+
+ def copy(self):
+ """Makes a copy.
+ Returns: new Spectrum
+ """
+ return copy.deepcopy(self)
+
+ def max_diff(self, other):
+ """Computes the maximum absolute difference between spectra.
+ other: Spectrum
+ returns: float
+ """
+ assert self.framerate == other.framerate
+ assert len(self) == len(other)
+
+ hs = self.hs - other.hs
+ return np.max(np.abs(hs))
+
+ def ratio(self, denom, thresh=1, val=0):
+ """The ratio of two spectrums.
+ denom: Spectrum
+ thresh: values smaller than this are replaced
+ val: with this value
+ returns: new Wave
+ """
+ ratio_spectrum = self.copy()
+ ratio_spectrum.hs /= denom.hs
+ ratio_spectrum.hs[denom.amps < thresh] = val
+ return ratio_spectrum
+
+ def invert(self):
+ """Inverts this spectrum/filter.
+ returns: new Wave
+ """
+ inverse = self.copy()
+ inverse.hs = 1 / inverse.hs
+ return inverse
+
+ @property
+ def freq_res(self):
+ return self.framerate / 2 / (len(self.fs) - 1)
+
+ def render_full(self, high=None):
+ """Extracts amps and fs from a full spectrum.
+ high: cutoff frequency
+ returns: fs, amps
+ """
+ hs = np.fft.fftshift(self.hs)
+ amps = np.abs(hs)
+ fs = np.fft.fftshift(self.fs)
+ i = 0 if high is None else find_index(-high, fs)
+ j = None if high is None else find_index(high, fs) + 1
+ return fs[i:j], amps[i:j]
+
+ def plot(self, high=None, **options):
+ """Plots amplitude vs frequency.
+ Note: if this is a full spectrum, it ignores low and high
+ high: frequency to cut off at
+ """
+ if self.full:
+ fs, amps = self.render_full(high)
+ thinkplot.plot(fs, amps, **options)
+ else:
+ i = None if high is None else find_index(high, self.fs)
+ thinkplot.plot(self.fs[:i], self.amps[:i], **options)
+
+ def plot_power(self, high=None, **options):
+ """Plots power vs frequency.
+ high: frequency to cut off at
+ """
+ if self.full:
+ fs, amps = self.render_full(high)
+ thinkplot.plot(fs, amps ** 2, **options)
+ else:
+ i = None if high is None else find_index(high, self.fs)
+ thinkplot.plot(self.fs[:i], self.power[:i], **options)
+
+ def estimate_slope(self):
+ """Runs linear regression on log power vs log frequency.
+ returns: slope, inter, r2, p, stderr
+ """
+ x = np.log(self.fs[1:])
+ y = np.log(self.power[1:])
+ t = scipy.stats.linregress(x, y)
+ return t
+
+ def peaks(self):
+ """Finds the highest peaks and their frequencies.
+ returns: sorted list of (amplitude, frequency) pairs
+ """
+ t = list(zip(self.amps, self.fs))
+ t.sort(reverse=True)
+ return t
+
+
+class Spectrum(_SpectrumParent):
+ """Represents the spectrum of a signal."""
+
+ def __len__(self):
+ """Length of the spectrum."""
+ return len(self.hs)
+
+ def __add__(self, other):
+ """Adds two spectrums elementwise.
+ other: Spectrum
+ returns: new Spectrum
+ """
+ if other == 0:
+ return self.copy()
+
+ assert all(self.fs == other.fs)
+ hs = self.hs + other.hs
+ return Spectrum(hs, self.fs, self.framerate, self.full)
+
+ __radd__ = __add__
+
+ def __mul__(self, other):
+ """Multiplies two spectrums elementwise.
+ other: Spectrum
+ returns: new Spectrum
+ """
+ assert all(self.fs == other.fs)
+ hs = self.hs * other.hs
+ return Spectrum(hs, self.fs, self.framerate, self.full)
+
+ def convolve(self, other):
+ """Convolves two Spectrums.
+ other: Spectrum
+ returns: Spectrum
+ """
+ assert all(self.fs == other.fs)
+ if self.full:
+ hs1 = np.fft.fftshift(self.hs)
+ hs2 = np.fft.fftshift(other.hs)
+ hs = np.convolve(hs1, hs2, mode="same")
+ hs = np.fft.ifftshift(hs)
+ else:
+ # not sure this branch would mean very much
+ hs = np.convolve(self.hs, other.hs, mode="same")
+
+ return Spectrum(hs, self.fs, self.framerate, self.full)
+
+ @property
+ def real(self):
+ """Returns the real part of the hs (read-only property)."""
+ return np.real(self.hs)
+
+ @property
+ def imag(self):
+ """Returns the imaginary part of the hs (read-only property)."""
+ return np.imag(self.hs)
+
+ @property
+ def angles(self):
+ """Returns a sequence of angles (read-only property)."""
+ return np.angle(self.hs)
+
+ def scale(self, factor):
+ """Multiplies all elements by the given factor.
+ factor: what to multiply the magnitude by (could be complex)
+ """
+ self.hs *= factor
+
+ def low_pass(self, cutoff, factor=0):
+ """Attenuate frequencies above the cutoff.
+ cutoff: frequency in Hz
+ factor: what to multiply the magnitude by
+ """
+ self.hs[abs(self.fs) > cutoff] *= factor
+
+ def high_pass(self, cutoff, factor=0):
+ """Attenuate frequencies below the cutoff.
+ cutoff: frequency in Hz
+ factor: what to multiply the magnitude by
+ """
+ self.hs[abs(self.fs) < cutoff] *= factor
+
+ def band_stop(self, low_cutoff, high_cutoff, factor=0):
+ """Attenuate frequencies between the cutoffs.
+ low_cutoff: frequency in Hz
+ high_cutoff: frequency in Hz
+ factor: what to multiply the magnitude by
+ """
+ # TODO: test this function
+ fs = abs(self.fs)
+ indices = (low_cutoff < fs) & (fs < high_cutoff)
+ self.hs[indices] *= factor
+
+ def pink_filter(self, beta=1):
+ """Apply a filter that would make white noise pink.
+ beta: exponent of the pink noise
+ """
+ denom = self.fs ** (beta / 2.0)
+ denom[0] = 1
+ self.hs /= denom
+
+ def differentiate(self):
+ """Apply the differentiation filter.
+ returns: new Spectrum
+ """
+ new = self.copy()
+ new.hs *= PI2 * 1j * new.fs
+ return new
+
+ def integrate(self):
+ """Apply the integration filter.
+ returns: new Spectrum
+ """
+ new = self.copy()
+ new.hs /= PI2 * 1j * new.fs
+ return new
+
+ def make_integrated_spectrum(self):
+ """Makes an integrated spectrum.
+ """
+ cs = np.cumsum(self.power)
+ cs /= cs[-1]
+ return IntegratedSpectrum(cs, self.fs)
+
+ def make_wave(self):
+ """Transforms to the time domain.
+ returns: Wave
+ """
+ if self.full:
+ ys = np.fft.ifft(self.hs)
+ else:
+ ys = np.fft.irfft(self.hs)
+
+ # NOTE: whatever the start time was, we lose it when
+ # we transform back; we could fix that by saving start
+ # time in the Spectrum
+ # ts = self.start + np.arange(len(ys)) / self.framerate
+ return Wave(ys, framerate=self.framerate)
+
+
+class IntegratedSpectrum:
+ """Represents the integral of a spectrum."""
+
+ def __init__(self, cs, fs):
+ """Initializes an integrated spectrum:
+ cs: sequence of cumulative amplitudes
+ fs: sequence of frequencies
+ """
+ self.cs = np.asanyarray(cs)
+ self.fs = np.asanyarray(fs)
+
+ def plot_power(self, low=0, high=None, expo=False, **options):
+ """Plots the integrated spectrum.
+ low: int index to start at
+ high: int index to end at
+ """
+ cs = self.cs[low:high]
+ fs = self.fs[low:high]
+
+ if expo:
+ cs = np.exp(cs)
+
+ thinkplot.plot(fs, cs, **options)
+
+ def estimate_slope(self, low=1, high=-12000):
+ """Runs linear regression on log cumulative power vs log frequency.
+ returns: slope, inter, r2, p, stderr
+ """
+ # print self.fs[low:high]
+ # print self.cs[low:high]
+ x = np.log(self.fs[low:high])
+ y = np.log(self.cs[low:high])
+ t = scipy.stats.linregress(x, y)
+ return t
+
+
+class Dct(_SpectrumParent):
+ """Represents the spectrum of a signal using discrete cosine transform."""
+
+ @property
+ def amps(self):
+ """Returns a sequence of amplitudes (read-only property).
+ Note: for DCTs, amps are positive or negative real.
+ """
+ return self.hs
+
+ def __add__(self, other):
+ """Adds two DCTs elementwise.
+ other: DCT
+ returns: new DCT
+ """
+ if other == 0:
+ return self
+
+ assert self.framerate == other.framerate
+ hs = self.hs + other.hs
+ return Dct(hs, self.fs, self.framerate)
+
+ __radd__ = __add__
+
+ def make_wave(self):
+ """Transforms to the time domain.
+ returns: Wave
+ """
+ N = len(self.hs)
+ ys = scipy.fftpack.idct(self.hs, type=2) / 2 / N
+ # NOTE: whatever the start time was, we lose it when
+ # we transform back
+ # ts = self.start + np.arange(len(ys)) / self.framerate
+ return Wave(ys, framerate=self.framerate)
+
+
+class Spectrogram:
+ """Represents the spectrum of a signal."""
+
+ def __init__(self, spec_map, seg_length):
+ """Initialize the spectrogram.
+ spec_map: map from float time to Spectrum
+ seg_length: number of samples in each segment
+ """
+ self.spec_map = spec_map
+ self.seg_length = seg_length
+
+ def any_spectrum(self):
+ """Returns an arbitrary spectrum from the spectrogram."""
+ index = next(iter(self.spec_map))
+ return self.spec_map[index]
+
+ @property
+ def time_res(self):
+ """Time resolution in seconds."""
+ spectrum = self.any_spectrum()
+ return float(self.seg_length) / spectrum.framerate
+
+ @property
+ def freq_res(self):
+ """Frequency resolution in Hz."""
+ return self.any_spectrum().freq_res
+
+ def times(self):
+ """Sorted sequence of times.
+ returns: sequence of float times in seconds
+ """
+ ts = sorted(iter(self.spec_map))
+ return ts
+
+ def frequencies(self):
+ """Sequence of frequencies.
+ returns: sequence of float freqencies in Hz.
+ """
+ fs = self.any_spectrum().fs
+ return fs
+
+ def plot(self, high=None, **options):
+ """Make a pseudocolor plot.
+ high: highest frequency component to plot
+ """
+ fs = self.frequencies()
+ i = None if high is None else find_index(high, fs)
+ fs = fs[:i]
+ ts = self.times()
+
+ # make the array
+ size = len(fs), len(ts)
+ array = np.zeros(size, dtype=np.float)
+
+ # copy amplitude from each spectrum into a column of the array
+ for j, t in enumerate(ts):
+ spectrum = self.spec_map[t]
+ array[:, j] = spectrum.amps[:i]
+
+ thinkplot.pcolor(ts, fs, array, **options)
+
+ def make_wave(self):
+ """Inverts the spectrogram and returns a Wave.
+ returns: Wave
+ """
+ res = []
+ for t, spectrum in sorted(self.spec_map.items()):
+ wave = spectrum.make_wave()
+ n = len(wave)
+
+ window = 1 / np.hamming(n)
+ wave.window(window)
+
+ i = wave.find_index(t)
+ start = i - n // 2
+ end = start + n
+ res.append((start, end, wave))
+
+ starts, ends, waves = zip(*res)
+ low = min(starts)
+ high = max(ends)
+
+ ys = np.zeros(high - low, np.float)
+ for start, end, wave in res:
+ ys[start:end] = wave.ys
+
+ # ts = np.arange(len(ys)) / self.framerate
+ return Wave(ys, framerate=wave.framerate)
+
+
+class Wave:
+ """Represents a discrete-time waveform.
+ """
+
+ def __init__(self, ys, ts=None, framerate=None):
+ """Initializes the wave.
+ ys: wave array
+ ts: array of times
+ framerate: samples per second
+ """
+ self.ys = np.asanyarray(ys)
+ self.framerate = framerate if framerate is not None else 11025
+
+ if ts is None:
+ self.ts = np.arange(len(ys)) / self.framerate
+ else:
+ self.ts = np.asanyarray(ts)
+
+ def copy(self):
+ """Makes a copy.
+ Returns: new Wave
+ """
+ return copy.deepcopy(self)
+
+ def __len__(self):
+ return len(self.ys)
+
+ @property
+ def start(self):
+ return self.ts[0]
+
+ @property
+ def end(self):
+ return self.ts[-1]
+
+ @property
+ def duration(self):
+ """Duration (property).
+ returns: float duration in seconds
+ """
+ return len(self.ys) / self.framerate
+
+ def __add__(self, other):
+ """Adds two waves elementwise.
+ other: Wave
+ returns: new Wave
+ """
+ if other == 0:
+ return self
+
+ assert self.framerate == other.framerate
+
+ # make an array of times that covers both waves
+ start = min(self.start, other.start)
+ end = max(self.end, other.end)
+ n = int(round((end - start) * self.framerate)) + 1
+ ys = np.zeros(n)
+ ts = start + np.arange(n) / self.framerate
+
+ def add_ys(wave):
+ i = find_index(wave.start, ts)
+
+ # make sure the arrays line up reasonably well
+ diff = ts[i] - wave.start
+ dt = 1 / wave.framerate
+ if (diff / dt) > 0.1:
+ warnings.warn(
+ "Can't add these waveforms; their " "time arrays don't line up."
+ )
+
+ j = i + len(wave)
+ ys[i:j] += wave.ys
+
+ add_ys(self)
+ add_ys(other)
+
+ return Wave(ys, ts, self.framerate)
+
+ __radd__ = __add__
+
+ def __or__(self, other):
+ """Concatenates two waves.
+ other: Wave
+ returns: new Wave
+ """
+ if self.framerate != other.framerate:
+ raise ValueError("Wave.__or__: framerates do not agree")
+
+ ys = np.concatenate((self.ys, other.ys))
+ # ts = np.arange(len(ys)) / self.framerate
+ return Wave(ys, framerate=self.framerate)
+
+ def __mul__(self, other):
+ """Multiplies two waves elementwise.
+ Note: this operation ignores the timestamps; the result
+ has the timestamps of self.
+ other: Wave
+ returns: new Wave
+ """
+ # the spectrums have to have the same framerate and duration
+ assert self.framerate == other.framerate
+ assert len(self) == len(other)
+
+ ys = self.ys * other.ys
+ return Wave(ys, self.ts, self.framerate)
+
+ def max_diff(self, other):
+ """Computes the maximum absolute difference between waves.
+ other: Wave
+ returns: float
+ """
+ assert self.framerate == other.framerate
+ assert len(self) == len(other)
+
+ ys = self.ys - other.ys
+ return np.max(np.abs(ys))
+
+ def convolve(self, other):
+ """Convolves two waves.
+ Note: this operation ignores the timestamps; the result
+ has the timestamps of self.
+ other: Wave or NumPy array
+ returns: Wave
+ """
+ if isinstance(other, Wave):
+ assert self.framerate == other.framerate
+ window = other.ys
+ else:
+ window = other
+
+ ys = np.convolve(self.ys, window, mode="full")
+ # ts = np.arange(len(ys)) / self.framerate
+ return Wave(ys, framerate=self.framerate)
+
+ def diff(self):
+ """Computes the difference between successive elements.
+ returns: new Wave
+ """
+ ys = np.diff(self.ys)
+ ts = self.ts[1:].copy()
+ return Wave(ys, ts, self.framerate)
+
+ def cumsum(self):
+ """Computes the cumulative sum of the elements.
+ returns: new Wave
+ """
+ ys = np.cumsum(self.ys)
+ ts = self.ts.copy()
+ return Wave(ys, ts, self.framerate)
+
+ def quantize(self, bound, dtype):
+ """Maps the waveform to quanta.
+ bound: maximum amplitude
+ dtype: numpy data type or string
+ returns: quantized signal
+ """
+ return quantize(self.ys, bound, dtype)
+
+ def apodize(self, denom=20, duration=0.1):
+ """Tapers the amplitude at the beginning and end of the signal.
+ Tapers either the given duration of time or the given
+ fraction of the total duration, whichever is less.
+ denom: float fraction of the segment to taper
+ duration: float duration of the taper in seconds
+ """
+ self.ys = apodize(self.ys, self.framerate, denom, duration)
+
+ def hamming(self):
+ """Apply a Hamming window to the wave.
+ """
+ self.ys *= np.hamming(len(self.ys))
+
+ def window(self, window):
+ """Apply a window to the wave.
+ window: sequence of multipliers, same length as self.ys
+ """
+ self.ys *= window
+
+ def scale(self, factor):
+ """Multplies the wave by a factor.
+ factor: scale factor
+ """
+ self.ys *= factor
+
+ def shift(self, shift):
+ """Shifts the wave left or right in time.
+ shift: float time shift
+ """
+ # TODO: track down other uses of this function and check them
+ self.ts += shift
+
+ def roll(self, roll):
+ """Rolls this wave by the given number of locations.
+ """
+ self.ys = np.roll(self.ys, roll)
+
+ def truncate(self, n):
+ """Trims this wave to the given length.
+ n: integer index
+ """
+ self.ys = truncate(self.ys, n)
+ self.ts = truncate(self.ts, n)
+
+ def zero_pad(self, n):
+ """Trims this wave to the given length.
+ n: integer index
+ """
+ self.ys = zero_pad(self.ys, n)
+ self.ts = self.start + np.arange(n) / self.framerate
+
+ def normalize(self, amp=1.0):
+ """Normalizes the signal to the given amplitude.
+ amp: float amplitude
+ """
+ self.ys = normalize(self.ys, amp=amp)
+
+ def unbias(self):
+ """Unbiases the signal.
+ """
+ self.ys = unbias(self.ys)
+
+ def find_index(self, t):
+ """Find the index corresponding to a given time."""
+ n = len(self)
+ start = self.start
+ end = self.end
+ i = round((n - 1) * (t - start) / (end - start))
+ return int(i)
+
+ def segment(self, start=None, duration=None):
+ """Extracts a segment.
+ start: float start time in seconds
+ duration: float duration in seconds
+ returns: Wave
+ """
+ if start is None:
+ start = self.ts[0]
+ i = 0
+ else:
+ i = self.find_index(start)
+
+ j = None if duration is None else self.find_index(start + duration)
+ return self.slice(i, j)
+
+ def slice(self, i, j):
+ """Makes a slice from a Wave.
+ i: first slice index
+ j: second slice index
+ """
+ ys = self.ys[i:j].copy()
+ ts = self.ts[i:j].copy()
+ return Wave(ys, ts, self.framerate)
+
+ def make_spectrum(self, full=False):
+ """Computes the spectrum using FFT.
+ returns: Spectrum
+ """
+ n = len(self.ys)
+ d = 1 / self.framerate
+
+ if full:
+ hs = np.fft.fft(self.ys)
+ fs = np.fft.fftfreq(n, d)
+ else:
+ hs = np.fft.rfft(self.ys)
+ fs = np.fft.rfftfreq(n, d)
+
+ return Spectrum(hs, fs, self.framerate, full)
+
+ def make_dct(self):
+ """Computes the DCT of this wave.
+ """
+ N = len(self.ys)
+ hs = scipy.fftpack.dct(self.ys, type=2)
+ fs = (0.5 + np.arange(N)) / 2
+ return Dct(hs, fs, self.framerate)
+
+ def make_spectrogram(self, seg_length, win_flag=True):
+ """Computes the spectrogram of the wave.
+ seg_length: number of samples in each segment
+ win_flag: boolean, whether to apply hamming window to each segment
+ returns: Spectrogram
+ """
+ if win_flag:
+ window = np.hamming(seg_length)
+ i, j = 0, seg_length
+ step = int(seg_length // 2)
+
+ # map from time to Spectrum
+ spec_map = {}
+
+ while j < len(self.ys):
+ segment = self.slice(i, j)
+ if win_flag:
+ segment.window(window)
+
+ # the nominal time for this segment is the midpoint
+ t = (segment.start + segment.end) / 2
+ spec_map[t] = segment.make_spectrum()
+
+ i += step
+ j += step
+
+ return Spectrogram(spec_map, seg_length)
+
+ def get_xfactor(self, options):
+ try:
+ xfactor = options["xfactor"]
+ options.pop("xfactor")
+ except KeyError:
+ xfactor = 1
+ return xfactor
+
+ def plot(self, **options):
+ """Plots the wave.
+ """
+ xfactor = self.get_xfactor(options)
+ thinkplot.plot(self.ts * xfactor, self.ys, **options)
+
+ def plot_vlines(self, **options):
+ """Plots the wave with vertical lines for samples.
+ """
+ xfactor = self.get_xfactor(options)
+ thinkplot.vlines(self.ts * xfactor, 0, self.ys, **options)
+
+ def corr(self, other):
+ """Correlation coefficient two waves.
+ other: Wave
+ returns: float coefficient of correlation
+ """
+ corr = np.corrcoef(self.ys, other.ys)[0, 1]
+ return corr
+
+ def cov_mat(self, other):
+ """Covariance matrix of two waves.
+ other: Wave
+ returns: 2x2 covariance matrix
+ """
+ return np.cov(self.ys, other.ys)
+
+ def cov(self, other):
+ """Covariance of two unbiased waves.
+ other: Wave
+ returns: float
+ """
+ total = sum(self.ys * other.ys) / len(self.ys)
+ return total
+
+ def cos_cov(self, k):
+ """Covariance with a cosine signal.
+ freq: freq of the cosine signal in Hz
+ returns: float covariance
+ """
+ n = len(self.ys)
+ factor = math.pi * k / n
+ ys = [math.cos(factor * (i + 0.5)) for i in range(n)]
+ total = 2 * sum(self.ys * ys)
+ return total
+
+ def cos_transform(self):
+ """Discrete cosine transform.
+ returns: list of frequency, cov pairs
+ """
+ n = len(self.ys)
+ res = []
+ for k in range(n):
+ cov = self.cos_cov(k)
+ res.append((k, cov))
+
+ return res
+
+ def write(self, filename="sound.wav"):
+ """Write a wave file.
+ filename: string
+ """
+ print("Writing", filename)
+ wfile = WavFileWriter(filename, self.framerate)
+ wfile.write(self)
+ wfile.close()
+
+ def play(self, filename="sound.wav"):
+ """Plays a wave file.
+ filename: string
+ """
+ self.write(filename)
+ play_wave(filename)
+
+ def make_audio(self):
+ """Makes an IPython Audio object.
+ """
+ audio = Audio(data=self.ys.real, rate=self.framerate)
+ return audio
+
+
+def unbias(ys):
+ """Shifts a wave array so it has mean 0.
+ ys: wave array
+ returns: wave array
+ """
+ return ys - ys.mean()
+
+
+def normalize(ys, amp=1.0):
+ """Normalizes a wave array so the maximum amplitude is +amp or -amp.
+ ys: wave array
+ amp: max amplitude (pos or neg) in result
+ returns: wave array
+ """
+ high, low = abs(max(ys)), abs(min(ys))
+ return amp * ys / max(high, low)
+
+
+def shift_right(ys, shift):
+ """Shifts a wave array to the right and zero pads.
+ ys: wave array
+ shift: integer shift
+ returns: wave array
+ """
+ res = np.zeros(len(ys) + shift)
+ res[shift:] = ys
+ return res
+
+
+def shift_left(ys, shift):
+ """Shifts a wave array to the left.
+ ys: wave array
+ shift: integer shift
+ returns: wave array
+ """
+ return ys[shift:]
+
+
+def truncate(ys, n):
+ """Trims a wave array to the given length.
+ ys: wave array
+ n: integer length
+ returns: wave array
+ """
+ return ys[:n]
+
+
+def quantize(ys, bound, dtype):
+ """Maps the waveform to quanta.
+ ys: wave array
+ bound: maximum amplitude
+ dtype: numpy data type of the result
+ returns: quantized signal
+ """
+ if max(ys) > 1 or min(ys) < -1:
+ warnings.warn("Warning: normalizing before quantizing.")
+ ys = normalize(ys)
+
+ zs = (ys * bound).astype(dtype)
+ return zs
+
+
+def apodize(ys, framerate, denom=20, duration=0.1):
+ """Tapers the amplitude at the beginning and end of the signal.
+ Tapers either the given duration of time or the given
+ fraction of the total duration, whichever is less.
+ ys: wave array
+ framerate: int frames per second
+ denom: float fraction of the segment to taper
+ duration: float duration of the taper in seconds
+ returns: wave array
+ """
+ # a fixed fraction of the segment
+ n = len(ys)
+ k1 = n // denom
+
+ # a fixed duration of time
+ k2 = int(duration * framerate)
+
+ k = min(k1, k2)
+
+ w1 = np.linspace(0, 1, k)
+ w2 = np.ones(n - 2 * k)
+ w3 = np.linspace(1, 0, k)
+
+ window = np.concatenate((w1, w2, w3))
+ return ys * window
+
+
+class Signal:
+ """Represents a time-varying signal."""
+
+ def __add__(self, other):
+ """Adds two signals.
+ other: Signal
+ returns: Signal
+ """
+ if other == 0:
+ return self
+ return SumSignal(self, other)
+
+ __radd__ = __add__
+
+ @property
+ def period(self):
+ """Period of the signal in seconds (property).
+ Since this is used primarily for purposes of plotting,
+ the default behavior is to return a value, 0.1 seconds,
+ that is reasonable for many signals.
+ returns: float seconds
+ """
+ return 0.1
+
+ def plot(self, framerate=11025):
+ """Plots the signal.
+ The default behavior is to plot three periods.
+ framerate: samples per second
+ """
+ duration = self.period * 3
+ wave = self.make_wave(duration, start=0, framerate=framerate)
+ wave.plot()
+
+ def make_wave(self, duration=1, start=0, framerate=11025):
+ """Makes a Wave object.
+ duration: float seconds
+ start: float seconds
+ framerate: int frames per second
+ returns: Wave
+ """
+ n = round(duration * framerate)
+ ts = start + np.arange(n) / framerate
+ ys = self.evaluate(ts)
+ return Wave(ys, ts, framerate=framerate)
+
+
+def infer_framerate(ts):
+ """Given ts, find the framerate.
+ Assumes that the ts are equally spaced.
+ ts: sequence of times in seconds
+ returns: frames per second
+ """
+ # TODO: confirm that this is never used and remove it
+ dt = ts[1] - ts[0]
+ framerate = 1.0 / dt
+ return framerate
+
+
+class SumSignal(Signal):
+ """Represents the sum of signals."""
+
+ def __init__(self, *args):
+ """Initializes the sum.
+ args: tuple of signals
+ """
+ self.signals = args
+
+ @property
+ def period(self):
+ """Period of the signal in seconds.
+ Note: this is not correct; it's mostly a placekeeper.
+ But it is correct for a harmonic sequence where all
+ component frequencies are multiples of the fundamental.
+ returns: float seconds
+ """
+ return max(sig.period for sig in self.signals)
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ return sum(sig.evaluate(ts) for sig in self.signals)
+
+
+class Sinusoid(Signal):
+ """Represents a sinusoidal signal."""
+
+ def __init__(self, freq=440, amp=1.0, offset=0, func=np.sin):
+ """Initializes a sinusoidal signal.
+ freq: float frequency in Hz
+ amp: float amplitude, 1.0 is nominal max
+ offset: float phase offset in radians
+ func: function that maps phase to amplitude
+ """
+ self.freq = freq
+ self.amp = amp
+ self.offset = offset
+ self.func = func
+
+ @property
+ def period(self):
+ """Period of the signal in seconds.
+ returns: float seconds
+ """
+ return 1.0 / self.freq
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ phases = PI2 * self.freq * ts + self.offset
+ ys = self.amp * self.func(phases)
+ return ys
+
+
+def CosSignal(freq=440, amp=1.0, offset=0):
+ """Makes a cosine Sinusoid.
+ freq: float frequency in Hz
+ amp: float amplitude, 1.0 is nominal max
+ offset: float phase offset in radians
+ returns: Sinusoid object
+ """
+ return Sinusoid(freq, amp, offset, func=np.cos)
+
+
+def SinSignal(freq=440, amp=1.0, offset=0):
+ """Makes a sine Sinusoid.
+ freq: float frequency in Hz
+ amp: float amplitude, 1.0 is nominal max
+ offset: float phase offset in radians
+ returns: Sinusoid object
+ """
+ return Sinusoid(freq, amp, offset, func=np.sin)
+
+
+def Sinc(freq=440, amp=1.0, offset=0):
+ """Makes a Sinc function.
+ freq: float frequency in Hz
+ amp: float amplitude, 1.0 is nominal max
+ offset: float phase offset in radians
+ returns: Sinusoid object
+ """
+ return Sinusoid(freq, amp, offset, func=np.sinc)
+
+
+class ComplexSinusoid(Sinusoid):
+ """Represents a complex exponential signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ phases = PI2 * self.freq * ts + self.offset
+ ys = self.amp * np.exp(1j * phases)
+ return ys
+
+
+class SquareSignal(Sinusoid):
+ """Represents a square signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ cycles = self.freq * ts + self.offset / PI2
+ frac, _ = np.modf(cycles)
+ ys = self.amp * np.sign(unbias(frac))
+ return ys
+
+
+class SawtoothSignal(Sinusoid):
+ """Represents a sawtooth signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ cycles = self.freq * ts + self.offset / PI2
+ frac, _ = np.modf(cycles)
+ ys = normalize(unbias(frac), self.amp)
+ return ys
+
+
+class ParabolicSignal(Sinusoid):
+ """Represents a parabolic signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ cycles = self.freq * ts + self.offset / PI2
+ frac, _ = np.modf(cycles)
+ ys = (frac - 0.5) ** 2
+ ys = normalize(unbias(ys), self.amp)
+ return ys
+
+
+class CubicSignal(ParabolicSignal):
+ """Represents a cubic signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ys = ParabolicSignal.evaluate(self, ts)
+ ys = np.cumsum(ys)
+ ys = normalize(unbias(ys), self.amp)
+ return ys
+
+
+class GlottalSignal(Sinusoid):
+ """Represents a periodic signal that resembles a glottal signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ cycles = self.freq * ts + self.offset / PI2
+ frac, _ = np.modf(cycles)
+ ys = frac ** 2 * (1 - frac)
+ ys = normalize(unbias(ys), self.amp)
+ return ys
+
+
+class TriangleSignal(Sinusoid):
+ """Represents a triangle signal."""
+
+ def evaluate(self, ts):
+ """Evaluates the signal at the given times.
+ ts: float array of times
+ returns: float wave array
+ """
+ ts = np.asarray(ts)
+ cycles = self.freq * ts + self.offset / PI2
+ frac, _ = np.modf(cycles)
+ ys = np.abs(frac - 0.5)
+ ys = normalize(unbias(ys), self.amp)
+ return ys
+
+
class Chirp(Signal):
    """Represents a signal with linearly varying frequency."""

    def __init__(self, start=440, end=880, amp=1.0):
        """Initializes a linear chirp.

        start: float frequency in Hz
        end: float frequency in Hz
        amp: float amplitude, 1.0 is nominal max
        """
        self.start = start
        self.end = end
        self.amp = amp

    @property
    def period(self):
        """Period of the signal in seconds.

        raises: ValueError, because a chirp is not periodic
        """
        # BUG FIX: the original RETURNED a ValueError instance instead of
        # raising it, so callers silently received an exception object.
        raise ValueError("Non-periodic signal.")

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        ts: float array of times
        returns: float wave array
        """
        # one frequency per interval between samples (len(ts) - 1 intervals)
        freqs = np.linspace(self.start, self.end, len(ts) - 1)
        return self._evaluate(ts, freqs)

    def _evaluate(self, ts, freqs):
        """Helper function that evaluates the signal.

        Accumulates phase over each interval, so the instantaneous
        frequency follows freqs.

        ts: float array of times
        freqs: float array of frequencies during each interval
        """
        dts = np.diff(ts)
        dps = PI2 * freqs * dts
        phases = np.cumsum(dps)
        phases = np.insert(phases, 0, 0)  # phase starts at zero
        ys = self.amp * np.cos(phases)
        return ys
+
+
class ExpoChirp(Chirp):
    """Represents a chirp whose frequency varies exponentially."""

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        ts: float array of times
        returns: float wave array
        """
        log_start = np.log10(self.start)
        log_end = np.log10(self.end)
        freqs = np.logspace(log_start, log_end, len(ts) - 1)
        return self._evaluate(ts, freqs)
+
+
class SilentSignal(Signal):
    """Represents silence."""

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        ts: float array of times
        returns: float wave array of zeros, same length as ts
        """
        n = len(ts)
        return np.zeros(n)
+
+
class Impulses(Signal):
    """Represents a series of impulses at the given locations."""

    def __init__(self, locations, amps=1):
        """Initializes the impulses.

        locations: sequence of times (seconds) where impulses occur
        amps: scalar or sequence of amplitudes, one per location
        """
        self.locations = np.asanyarray(locations)
        self.amps = amps

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        ts: float array of times
        returns: float wave array
        """
        ys = np.zeros(len(ts))
        # NOTE(review): searchsorted returns len(ts) for a location past
        # the end of ts, which would make the assignment below raise
        # IndexError -- assumes all locations fall within ts; TODO confirm.
        indices = np.searchsorted(ts, self.locations)
        ys[indices] = self.amps
        return ys
+
+
class _Noise(Signal):
    """Represents a noise signal (abstract parent class)."""

    def __init__(self, amp=1.0):
        """Initializes a noise signal.

        amp: float amplitude, 1.0 is nominal max
        """
        self.amp = amp

    @property
    def period(self):
        """Period of the signal in seconds.

        raises: ValueError, because noise is not periodic
        """
        # BUG FIX: the original RETURNED a ValueError instance instead of
        # raising it, so callers silently received an exception object.
        raise ValueError("Non-periodic signal.")
+
+
class UncorrelatedUniformNoise(_Noise):
    """Represents uncorrelated uniform noise."""

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        ts: float array of times
        returns: float wave array
        """
        n = len(ts)
        return np.random.uniform(low=-self.amp, high=self.amp, size=n)
+
+
class UncorrelatedGaussianNoise(_Noise):
    """Represents uncorrelated gaussian noise."""

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        ts: float array of times
        returns: float wave array
        """
        n = len(ts)
        return np.random.normal(loc=0, scale=self.amp, size=n)
+
+
class BrownianNoise(_Noise):
    """Represents Brownian noise, aka red noise."""

    def evaluate(self, ts):
        """Evaluates the signal at the given times.

        Computes Brownian noise by taking the cumulative sum of
        a uniform random series (a random walk).

        ts: float array of times
        returns: float wave array
        """
        steps = np.random.uniform(-1, 1, len(ts))
        walk = np.cumsum(steps)
        return normalize(unbias(walk), self.amp)
+
+
class PinkNoise(_Noise):
    """Represents pink noise, with a 1/f**beta power spectrum."""

    def __init__(self, amp=1.0, beta=1.0):
        """Initializes a pink noise signal.

        amp: float amplitude, 1.0 is nominal max
        beta: float exponent of the 1/f**beta spectrum (1.0 is classic pink)
        """
        self.amp = amp
        self.beta = beta

    def make_wave(self, duration=1, start=0, framerate=11025):
        """Makes a Wave object.

        Generates white noise, then filters its spectrum to give it
        a 1/f**beta power distribution.

        duration: float seconds
        start: float seconds
        framerate: int frames per second
        returns: Wave
        """
        signal = UncorrelatedUniformNoise()
        wave = signal.make_wave(duration, start, framerate)
        spectrum = wave.make_spectrum()

        # attenuate higher frequencies according to 1/f**beta
        spectrum.pink_filter(beta=self.beta)

        wave2 = spectrum.make_wave()
        wave2.unbias()
        wave2.normalize(self.amp)
        return wave2
+
+
def rest(duration):
    """Makes a silent wave of the given duration.

    duration: float seconds
    returns: Wave
    """
    return SilentSignal().make_wave(duration)
+
+
def make_note(midi_num, duration, sig_cons=CosSignal, framerate=11025):
    """Make a MIDI note with the given duration.

    midi_num: int MIDI note number
    duration: float seconds
    sig_cons: Signal constructor function
    framerate: int frames per second
    returns: Wave
    """
    signal = sig_cons(midi_to_freq(midi_num))
    wave = signal.make_wave(duration, framerate=framerate)
    wave.apodize()  # taper the ends to avoid clicks
    return wave
+
+
def make_chord(midi_nums, duration, sig_cons=CosSignal, framerate=11025):
    """Make a chord with the given duration.

    midi_nums: sequence of int MIDI note numbers
    duration: float seconds
    sig_cons: Signal constructor function
    framerate: int frames per second
    returns: Wave
    """
    signals = [sig_cons(midi_to_freq(num)) for num in midi_nums]
    wave = sum(signals).make_wave(duration, framerate=framerate)
    wave.apodize()  # taper the ends to avoid clicks
    return wave
+
+
def midi_to_freq(midi_num):
    """Converts MIDI note number to frequency.

    Uses A440 equal temperament: note 69 is 440 Hz and each
    semitone is a factor of 2**(1/12).

    midi_num: int MIDI note number
    returns: float frequency in Hz
    """
    semitones_from_a4 = midi_num - 69
    return 440.0 * 2 ** (semitones_from_a4 / 12.0)
+
+
def sin_wave(freq, duration=1, offset=0):
    """Makes a sine wave with the given parameters.

    freq: float cycles per second
    duration: float seconds
    offset: float radians
    returns: Wave
    """
    return SinSignal(freq, offset=offset).make_wave(duration)
+
+
def cos_wave(freq, duration=1, offset=0):
    """Makes a cosine wave with the given parameters.

    freq: float cycles per second
    duration: float seconds
    offset: float radians
    returns: Wave
    """
    return CosSignal(freq, offset=offset).make_wave(duration)
+
+
def mag(a):
    """Computes the magnitude (Euclidean norm) of a numpy array.

    a: numpy array
    returns: float
    """
    a = np.asarray(a)
    return np.sqrt(a @ a)
+
+
def zero_pad(array, n):
    """Extends an array with zeros up to length n.

    array: numpy array
    n: length of result
    returns: new NumPy array (float)
    """
    padded = np.zeros(n)
    np.copyto(padded[: len(array)], array)
    return padded
+
+
def main():
    """Demo entry point.

    Contains several experiments; each ends with an early ``return``,
    so only the first section (projecting a phase-shifted cosine onto
    cos/sin bases) actually runs -- everything after the first return
    is intentionally dead code kept for reference.
    """
    # project a phase-shifted cosine onto cos/sin bases via covariance
    cos_basis = cos_wave(440)
    sin_basis = sin_wave(440)

    wave = cos_wave(440, offset=math.pi / 2)
    cos_cov = cos_basis.cov(wave)
    sin_cov = sin_basis.cov(wave)
    print(cos_cov, sin_cov, mag((cos_cov, sin_cov)))
    return

    # unreachable: write one apodized wave per signal shape to a WAV file
    wfile = WavFileWriter()
    for sig_cons in [
        SinSignal,
        TriangleSignal,
        SawtoothSignal,
        GlottalSignal,
        ParabolicSignal,
        SquareSignal,
    ]:
        print(sig_cons)
        sig = sig_cons(440)
        wave = sig.make_wave(1)
        wave.apodize()
        wfile.write(wave)
    wfile.close()
    return

    # unreachable: plot a glottal signal
    signal = GlottalSignal(440)
    signal.plot()
    pyplot.show()
    return

    # unreachable: write a descending chromatic scale
    wfile = WavFileWriter()
    for m in range(60, 0, -1):
        wfile.write(make_note(m, 0.25))
    wfile.close()
    return

    # unreachable: a note followed by a chord
    wave1 = make_note(69, 1)
    wave2 = make_chord([69, 72, 76], 1)
    wave = wave1 | wave2

    wfile = WavFileWriter()
    wfile.write(wave)
    wfile.close()
    return

    # unreachable: a sum of cosines (sig5 is created but unused)
    sig1 = CosSignal(freq=440)
    sig2 = CosSignal(freq=523.25)
    sig3 = CosSignal(freq=660)
    sig4 = CosSignal(freq=880)
    sig5 = CosSignal(freq=987)
    sig = sig1 + sig2 + sig3 + sig4

    # wave = Wave(sig, duration=0.02)
    # wave.plot()

    wave = sig.make_wave(duration=1)
    # wave.normalize()

    wfile = WavFileWriter(wave)
    wfile.write()
    wfile.close()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/thinkplot.py b/thinkplot.py
new file mode 100644
index 0000000..aceeffb
--- /dev/null
+++ b/thinkplot.py
@@ -0,0 +1,838 @@
+from __future__ import print_function
+
+import math
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+import warnings
+
+# customize some matplotlib attributes
+#matplotlib.rc('figure', figsize=(4, 3))
+
+#matplotlib.rc('font', size=14.0)
+#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
+#matplotlib.rc('legend', fontsize=20.0)
+
+#matplotlib.rc('xtick.major', size=6.0)
+#matplotlib.rc('xtick.minor', size=3.0)
+
+#matplotlib.rc('ytick.major', size=6.0)
+#matplotlib.rc('ytick.minor', size=3.0)
+
+
+class _Brewer(object):
+ """Encapsulates a nice sequence of colors.
+ Shades of blue that look good in color and can be distinguished
+ in grayscale (up to a point).
+ Borrowed from http://colorbrewer2.org/
+ """
+ color_iter = None
+
+ colors = ['#f7fbff', '#deebf7', '#c6dbef',
+ '#9ecae1', '#6baed6', '#4292c6',
+ '#2171b5','#08519c','#08306b'][::-1]
+
+ # lists that indicate which colors to use depending on how many are used
+ which_colors = [[],
+ [1],
+ [1, 3],
+ [0, 2, 4],
+ [0, 2, 4, 6],
+ [0, 2, 3, 5, 6],
+ [0, 2, 3, 4, 5, 6],
+ [0, 1, 2, 3, 4, 5, 6],
+ [0, 1, 2, 3, 4, 5, 6, 7],
+ [0, 1, 2, 3, 4, 5, 6, 7, 8],
+ ]
+
+ current_figure = None
+
+ @classmethod
+ def Colors(cls):
+ """Returns the list of colors.
+ """
+ return cls.colors
+
+ @classmethod
+ def ColorGenerator(cls, num):
+ """Returns an iterator of color strings.
+ n: how many colors will be used
+ """
+ for i in cls.which_colors[num]:
+ yield cls.colors[i]
+ raise StopIteration('Ran out of colors in _Brewer.')
+
+ @classmethod
+ def InitIter(cls, num):
+ """Initializes the color iterator with the given number of colors."""
+ cls.color_iter = cls.ColorGenerator(num)
+ fig = plt.gcf()
+ cls.current_figure = fig
+
+ @classmethod
+ def ClearIter(cls):
+ """Sets the color iterator to None."""
+ cls.color_iter = None
+ cls.current_figure = None
+
+ @classmethod
+ def GetIter(cls, num):
+ """Gets the color iterator."""
+ fig = plt.gcf()
+ if fig != cls.current_figure:
+ cls.InitIter(num)
+ cls.current_figure = fig
+
+ if cls.color_iter is None:
+ cls.InitIter(num)
+
+ return cls.color_iter
+
+
def _UnderrideColor(options):
    """If color is not in the options, chooses a color from _Brewer."""
    if 'color' in options:
        return options

    # get the current color iterator; if there is none, init one
    color_iter = _Brewer.GetIter(5)

    try:
        options['color'] = next(color_iter)
        return options
    except StopIteration:
        # out of colors -- reset the iterator and start over
        warnings.warn('Ran out of colors. Starting over.')
        _Brewer.ClearIter()
        return _UnderrideColor(options)
+
+
def PrePlot(num=None, rows=None, cols=None):
    """Takes hints about what's coming.

    num: number of lines that will be plotted
    rows: number of rows of subplots
    cols: number of columns of subplots
    """
    if num:
        _Brewer.InitIter(num)

    if rows is None and cols is None:
        return

    # if only one of rows/cols is given, default the other to 1
    if rows is not None and cols is None:
        cols = 1
    if cols is not None and rows is None:
        rows = 1

    # resize the image, depending on the number of rows and cols
    size_map = {(1, 1): (8, 6),
                (1, 2): (12, 6),
                (1, 3): (12, 6),
                (1, 4): (12, 5),
                (1, 5): (12, 4),
                (2, 2): (10, 10),
                (2, 3): (16, 10),
                (3, 1): (8, 10),
                (4, 1): (8, 12),
                }

    size = size_map.get((rows, cols))
    if size is not None:
        plt.gcf().set_size_inches(*size)

    # create the first subplot and remember the grid shape for SubPlot
    if rows > 1 or cols > 1:
        ax = plt.subplot(rows, cols, 1)
        global SUBPLOT_ROWS, SUBPLOT_COLS
        SUBPLOT_ROWS = rows
        SUBPLOT_COLS = cols
    else:
        ax = plt.gca()

    return ax
+
+
def SubPlot(plot_number, rows=None, cols=None, **options):
    """Configures the number of subplots and changes the current plot.

    rows: int
    cols: int
    plot_number: int
    options: passed to subplot
    """
    # fall back to the grid shape remembered by PrePlot
    nrows = rows if rows else SUBPLOT_ROWS
    ncols = cols if cols else SUBPLOT_COLS
    return plt.subplot(nrows, ncols, plot_number, **options)
+
+
+def _Underride(d, **options):
+ """Add key-value pairs to d only if key is not in d.
+ If d is None, create a new dictionary.
+ d: dictionary
+ options: keyword args to add to d
+ """
+ if d is None:
+ d = {}
+
+ for key, val in options.items():
+ d.setdefault(key, val)
+
+ return d
+
+
def Clf():
    """Clears the figure and any hints that have been set."""
    global LOC
    LOC = None
    _Brewer.ClearIter()
    plt.clf()
    plt.gcf().set_size_inches(8, 6)
+
+
def Figure(**options):
    """Sets options for the current figure."""
    options = _Underride(options, figsize=(6, 8))
    plt.figure(**options)
+
+
def Plot(obj, ys=None, style='', **options):
    """Plots a line.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values
      style: style string passed along to plt.plot
      options: keyword args passed to plt.plot
    """
    options = _UnderrideColor(options)
    # objects may carry their own legend label
    label = getattr(obj, 'label', '_nolegend_')
    options = _Underride(options, linewidth=3, alpha=0.7, label=label)

    xs = obj
    if ys is None:
        # obj alone may carry both coordinates: a Render()-able object,
        # or a pandas Series (index -> x, values -> y)
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pd.Series):
            ys = obj.values
            xs = obj.index

    if ys is None:
        # still only one sequence: plot xs against its own indices
        plt.plot(xs, style, **options)
    else:
        plt.plot(xs, ys, style, **options)
+
+
def Vlines(xs, y1, y2, **options):
    """Plots a set of vertical lines.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values
      options: keyword args passed to plt.vlines
    """
    options = _Underride(_UnderrideColor(options), linewidth=1, alpha=0.5)
    plt.vlines(xs, y1, y2, **options)
+
+
def Hlines(ys, x1, x2, **options):
    """Plots a set of horizontal lines.

    Args:
      ys: sequence of y values
      x1: sequence of x values
      x2: sequence of x values
      options: keyword args passed to plt.hlines
    """
    options = _Underride(_UnderrideColor(options), linewidth=1, alpha=0.5)
    plt.hlines(ys, x1, x2, **options)
+
+
def axvline(x, **options):
    """Plots a vertical line across the axes.

    Args:
      x: x location
      options: keyword args passed to plt.axvline
    """
    options = _Underride(_UnderrideColor(options), linewidth=1, alpha=0.5)
    plt.axvline(x, **options)
+
+
def axhline(y, **options):
    """Plots a horizontal line across the axes.

    Args:
      y: y location
      options: keyword args passed to plt.axhline
    """
    options = _Underride(_UnderrideColor(options), linewidth=1, alpha=0.5)
    plt.axhline(y, **options)
+
+
def tight_layout(**options):
    """Adjust subplots to minimize padding and margins."""
    defaults = dict(wspace=0.1, hspace=0.1,
                    left=0, right=1,
                    bottom=0, top=1)
    options = _Underride(options, **defaults)
    plt.tight_layout()
    plt.subplots_adjust(**options)
+
+
def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the space between two lines.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values
      where: sequence of boolean
      options: keyword args passed to plt.fill_between
    """
    options = _Underride(_UnderrideColor(options), linewidth=0, alpha=0.5)
    plt.fill_between(xs, y1, y2, where, **options)
+
+
def Bar(xs, ys, **options):
    """Plots a bar chart.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      options: keyword args passed to plt.bar
    """
    options = _Underride(_UnderrideColor(options), linewidth=0, alpha=0.6)
    plt.bar(xs, ys, **options)
+
+
def Scatter(xs, ys=None, **options):
    """Makes a scatter plot.

    xs: x values (or a pandas Series carrying both coordinates)
    ys: y values
    options: options passed to plt.scatter
    """
    options = _Underride(options, color='blue', alpha=0.2,
                         s=30, edgecolors='none')

    # a Series carries both coordinates: index -> x, values -> y
    if ys is None and isinstance(xs, pd.Series):
        xs, ys = xs.index, xs.values

    plt.scatter(xs, ys, **options)
+
+
def HexBin(xs, ys, **options):
    """Makes a hexbin plot.

    xs: x values
    ys: y values
    options: options passed to plt.hexbin
    """
    options = _Underride(options, cmap=matplotlib.cm.Blues)
    plt.hexbin(xs, ys, **options)
+
+
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to plt.plot
    """
    # pull out the rendering options; the rest go to plt.plot
    low = options.pop('low', None)
    high = options.pop('high', None)
    n = options.pop('n', 101)
    xs, ps = pdf.Render(low=low, high=high, n=n)
    options = _Underride(options, label=pdf.label)
    Plot(xs, ps, **options)
+
+
def Pdfs(pdfs, **options):
    """Plots a sequence of PDFs.

    Options are passed along for all PDFs.  If you want different
    options for each pdf, make multiple calls to Pdf.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to plt.plot
    """
    for p in pdfs:
        Pdf(p, **options)
+
+
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist.  If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to plt.bar
    """
    # find the minimum distance between adjacent values
    xs, ys = hist.Render()

    # see if the values support arithmetic
    try:
        xs[0] - xs[0]
    except TypeError:
        # if not, replace values with numbers and label the ticks
        labels = [str(x) for x in xs]
        xs = np.arange(len(xs))
        plt.xticks(xs+0.5, labels)

    if 'width' not in options:
        try:
            options['width'] = 0.9 * np.diff(xs).min()
        except TypeError:
            warnings.warn("Hist: Can't compute bar width automatically."
                          "Check for non-numeric types in Hist."
                          "Or try providing width option."
                          )

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        # matplotlib only supports 'edge'; a negative width flips the bar
        # NOTE(review): if the width computation above warned instead of
        # setting 'width', the *= below raises KeyError -- TODO confirm
        options['align'] = 'edge'
        options['width'] *= -1

    Bar(xs, ys, **options)
+
+
def Hists(hists, **options):
    """Plots two histograms as interleaved bar plots.

    Options are passed along for all PMFs.  If you want different
    options for each pmf, make multiple calls to Pmf.

    Args:
      hists: list of two Hist or Pmf objects
      options: keyword args passed to plt.plot
    """
    for h in hists:
        Hist(h, **options)
+
+
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a step-function outline.

    Args:
      pmf: Hist or Pmf object
      options: keyword args passed to plt.plot
    """
    xs, ys = pmf.Render()
    # NOTE(review): low/high are computed but never used
    low, high = min(xs), max(xs)

    width = options.pop('width', None)
    if width is None:
        try:
            width = np.diff(xs).min()
        except TypeError:
            # NOTE(review): width stays None here, so the x+width
            # arithmetic below would raise TypeError -- TODO confirm
            warnings.warn("Pmf: Can't compute bar width automatically."
                          "Check for non-numeric types in Pmf."
                          "Or try providing width option.")
    points = []

    # build the outline of the bars: drop to 0 wherever there is a gap
    # wider than 1e-5 between the end of one bar and the start of the next
    lastx = np.nan
    lasty = 0
    for x, y in zip(xs, ys):
        if (x - lastx) > 1e-5:
            points.append((lastx, 0))
            points.append((x, 0))

        points.append((x, lasty))
        points.append((x, y))
        points.append((x+width, y))

        lastx = x + width
        lasty = y
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    # shift the outline according to bar alignment
    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width/2.0
    if align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.label)
    Plot(pxs, pys, **options)
+
+
def Pmfs(pmfs, **options):
    """Plots a sequence of PMFs.

    Options are passed along for all PMFs.  If you want different
    options for each pmf, make multiple calls to Pmf.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to plt.plot
    """
    for p in pmfs:
        Pmf(p, **options)
+
+
def Diff(t):
    """Compute the differences between adjacent elements in a sequence.

    Args:
      t: sequence of number

    Returns:
      list of differences (length one less than t)
    """
    return [b - a for a, b in zip(t, t[1:])]
+
+
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to plt.plot

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    # explicit scale options take precedence over the transform defaults
    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    # exponential data is linear on a log-y complementary CDF
    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    # Pareto data is linear on a log-log complementary CDF
    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0-p for p in ps]

    # Weibull data is linear under -log(1-p) on log-log axes;
    # drop the last point where p == 1 (log(0) is undefined)
    if transform == 'weibull':
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0-p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    # Gumbel data is linear under -log(p) on a log-y axis;
    # drop the first point where p == 0
    if transform == 'gumbel':
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale
+
+
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots a sequence of CDFs.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to plt.plot
    """
    for c in cdfs:
        Cdf(c, complement, transform, **options)
+
+
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use plt.imshow
    options: keyword args passed to plt.pcolor and/or plt.contour
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    # build the grid from the unique sorted x and y coordinates
    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    # missing (x, y) pairs default to z = 0
    func = lambda x, y: d.get((x, y), 0)
    func = np.vectorize(func)
    Z = func(X, Y)

    # avoid offset notation on the x axis
    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = plt.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        plt.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = plt.contour(X, Y, Z, **options)
        plt.clabel(cs, inline=1, fontsize=10)
    if imshow:
        extent = xs[0], xs[-1], ys[0], ys[-1]
        plt.imshow(Z, extent=extent, **options)
+
+
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: sequence of x values
    ys: sequence of y values
    zs: 2-D array of z values
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to plt.pcolor and/or plt.contour
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    X, Y = np.meshgrid(xs, ys)
    Z = zs

    # avoid offset notation on the x axis
    formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    plt.gca().xaxis.set_major_formatter(formatter)

    if pcolor:
        plt.pcolormesh(X, Y, Z, **options)

    if contour:
        cs = plt.contour(X, Y, Z, **options)
        plt.clabel(cs, inline=1, fontsize=10)
+
+
def Text(x, y, s, **options):
    """Puts text in a figure.

    x: number
    y: number
    s: string
    options: keyword args passed to plt.text
    """
    defaults = dict(fontsize=16,
                    verticalalignment='top',
                    horizontalalignment='left')
    options = _Underride(options, **defaults)
    plt.text(x, y, s, **options)
+
+
# module-level legend state, shared by Config/Show/Save across calls
LEGEND = True
LOC = None
+
def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding plt functions.
    """
    # forward simple axis options straight to the same-named plt function
    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    for name in names:
        if name in options:
            getattr(plt, name)(options[name])

    # legend visibility persists across calls via the module global
    global LEGEND
    LEGEND = options.get('legend', LEGEND)

    # see if there are any elements with labels;
    # if not, don't draw a legend
    ax = plt.gca()
    handles, labels = ax.get_legend_handles_labels()

    if LEGEND and len(labels) > 0:
        global LOC
        LOC = options.get('loc', LOC)
        frameon = options.get('frameon', True)

        try:
            plt.legend(loc=LOC, frameon=frameon)
        except UserWarning:
            # NOTE(review): plt.legend normally *warns* rather than raises
            # UserWarning, so this handler is likely a no-op -- confirm
            pass

    # x and y ticklabels can be made invisible
    val = options.get('xticklabels', None)
    if val is not None:
        if val == 'invisible':
            ax = plt.gca()
            labels = ax.get_xticklabels()
            plt.setp(labels, visible=False)

    val = options.get('yticklabels', None)
    if val is not None:
        if val == 'invisible':
            ax = plt.gca()
            labels = ax.get_yticklabels()
            plt.setp(labels, visible=False)
+
def set_font_size(title_size=16, label_size=16, ticklabel_size=14, legend_size=14):
    """Set font sizes for the title, labels, ticklabels, and legend."""
    def resize_all(texts, size):
        for text in texts:
            text.set_size(size)

    ax = plt.gca()

    # TODO: Make this function more robust if any of these elements
    # is missing.

    # title
    ax.title.set_size(title_size)

    # x axis label and tick labels
    ax.xaxis.label.set_size(label_size)
    resize_all(ax.xaxis.get_ticklabels(), ticklabel_size)

    # y axis label and tick labels
    ax.yaxis.label.set_size(label_size)
    resize_all(ax.yaxis.get_ticklabels(), ticklabel_size)

    # legend, if there is one
    legend = ax.get_legend()
    if legend is not None:
        resize_all(legend.texts, legend_size)
+
+
def bigger_text():
    """Applies a larger set of default font sizes to the current axes."""
    set_font_size(title_size=16, label_size=16,
                  ticklabel_size=14, legend_size=14)
+
+
def Show(**options):
    """Shows the plot.

    For options, see Config.

    options: keyword args used to invoke various plt functions
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    plt.show()
    if clear_after:
        Clf()
+
+
def Plotly(**options):
    """Uploads the current figure to plotly and shows it.

    For options, see Config.

    options: keyword args used to invoke various plt functions

    returns: string URL of the plotly figure
    """
    clf = options.pop('clf', True)
    Config(**options)
    # NOTE(review): plotly.plotly is the legacy online-only API module,
    # removed in modern plotly releases (moved to chart_studio) -- confirm
    # the installed plotly version supports it
    import plotly.plotly as plotly
    url = plotly.plot_mpl(plt.gcf())
    if clf:
        Clf()
    return url
+
+
def Save(root=None, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see Config.

    Note: With a capital S, this is the original save, maintained for
    compatibility.  New code should use save(), which works better
    with my newer code, especially in Jupyter notebooks.

    Args:
      root: string filename root
      formats: list of string formats
      options: keyword args used to invoke various plt functions
    """
    clf = options.pop('clf', True)

    # separate the options that belong to savefig from the Config ones
    save_options = {}
    for option in ['bbox_inches', 'pad_inches']:
        if option in options:
            save_options[option] = options.pop(option)

    # TODO: falling Config inside Save was probably a mistake, but removing
    # it will require some work
    Config(**options)

    if formats is None:
        formats = ['pdf', 'png']

    # 'plotly' is a pseudo-format handled by Plotly(), not savefig
    try:
        formats.remove('plotly')
        Plotly(clf=False)
    except ValueError:
        pass

    if root:
        for fmt in formats:
            SaveFormat(root, fmt, **save_options)
    if clf:
        Clf()
+
+
def save(root, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see plt.savefig.

    Args:
      root: string filename root
      formats: list of string formats
      options: keyword args passed to plt.savefig
    """
    formats = ['pdf', 'png'] if formats is None else formats

    # 'plotly' is a pseudo-format handled by Plotly(), not savefig
    try:
        formats.remove('plotly')
    except ValueError:
        pass
    else:
        Plotly(clf=False)

    for fmt in formats:
        SaveFormat(root, fmt, **options)
+
+
def SaveFormat(root, fmt='eps', **options):
    """Writes the current figure to a file in the given format.

    Args:
      root: string filename root
      fmt: string format
    """
    _Underride(options, dpi=300)
    filename = '{}.{}'.format(root, fmt)
    print('Writing', filename)
    plt.savefig(filename, format=fmt, **options)
+
+
# provide aliases for calling functions with lower-case names
# (the original API uses CapWords; these let new code use PEP 8 names)
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
vlines = Vlines
hlines = Hlines
fill_between = FillBetween
text = Text
scatter = Scatter
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
config = Config
show = Show
+
+
def main():
    """Prints the 7-color Brewer palette, one color per line."""
    for color in _Brewer.ColorGenerator(7):
        print(color)
+
+
# script entry point: print the demo palette
if __name__ == '__main__':
    main()
\ No newline at end of file
diff --git a/total.py b/total.py
new file mode 100644
index 0000000..136967f
--- /dev/null
+++ b/total.py
@@ -0,0 +1,339 @@
+import os
+import glob
+import shutil
+
+from PIL import Image
+import random
+
+from lxml import etree
+
# All paths must end with '/'
your_voc_path = 'C:/Users/c9347/Desktop/voc/'  # VOC dataset output path
yolo_path = 'D:/labels/positive/'  # original (YOLO-format) dataset path
yolo_filtered_path = 'D:/labes_29/'  # filtered YOLO dataset output path
# trainval_percent = 0.9
# train_percent = 0.9
threshold = 200  # image-count threshold: keep classes with more than this many images
# threshold_cls = 15  # class-count threshold: keep the 15 most frequent classes

# Class names; a label's index in this list is its YOLO class id (53 classes).
labels = ['road roller', 'bar deposits', 'piece deposits', 'brick', 'earth vehicles', 'tower', 'digger', 'bulldozer',
          'drill', 'crane', 'concrete truck', 'mixer', 'concrete simple house', 'simple house', 'green cover',
          'black cover', 'blue enclosure', 'grey enclosure', 'color enclosure', 'building', 'groove', 'big building',
          'building frame', 'scaffold', 'vehicle', 'grave mound', 'garbage', 'crushed stones', 'bricks', 'greenhouse',
          'site shanty', 'woodpile', 'fuel tank', 'big truck', 'car', 'boxcar', 'small truck', 'van car',
          'watering car', 'tutu', 'crane closed', 'Agricultural Tricycles', 'bus', 'pickup', 'large cement pipes',
          'middle cement pipes', 'small cement pipes', 'thin steel pipe', 'crude steel pipe', 'big stell pipe', 'slab',
          'U-steel', 'road leveling machine']
# Create the required directory trees (original VOC tree, filtered VOC tree,
# and the filtered YOLO output folder).
# os.makedirs(..., exist_ok=True) replaces the repetitive
# "if not os.path.exists: os.makedirs" pattern, which was verbose and also
# subject to a check-then-create race.
_voc_subdirs = [
    'Annotations',
    'ImageSets/Layout',
    'ImageSets/Main',
    'ImageSets/Segmentation',
    'JPEGImages',
    'labels',
]
for _devkit in ('VOCdevkit', 'VOCdevkit_filtered'):
    for _sub in _voc_subdirs:
        os.makedirs(your_voc_path + _devkit + '/VOC2007/' + _sub, exist_ok=True)
os.makedirs(yolo_filtered_path, exist_ok=True)
# Rename the data to sequential zero-padded ids and copy each image into the
# VOC JPEGImages folder.  Every image is expected to have a matching .txt
# annotation next to it.
listdir = os.listdir(yolo_path)
count = 0
for i, file in enumerate(listdir):
    if i % 100 == 0:
        print(i)
    # os.path.splitext splits a file name into (stem, extension).
    filename, filetype = os.path.splitext(file)
    if filetype == '.txt':
        # Annotation files are renamed together with their image below.
        continue
    Olddir = os.path.join(yolo_path, file)
    Newdir = os.path.join(yolo_path, str(count).zfill(6) + '.jpg')
    Oldanno = os.path.join(yolo_path, filename + '.txt')
    Newanno = os.path.join(yolo_path, str(count).zfill(6) + '.txt')
    if not os.path.exists(Oldanno):
        # Previously os.rename crashed with FileNotFoundError when an image
        # had no matching annotation; skip such images instead.
        print('missing annotation for', file, '- skipped')
        continue
    os.rename(Olddir, Newdir)
    os.rename(Oldanno, Newanno)
    shutil.copyfile(Newdir, your_voc_path + 'VOCdevkit/VOC2007/JPEGImages/' + str(count).zfill(6) + '.jpg')
    count += 1
+
# Generate the VOC-format annotation files (one .xml per image).
voc_xml = your_voc_path + 'VOCdevkit/VOC2007/Annotations/'

# glob all .jpg files under the path; returns a list of full paths
img_glob = glob.glob(yolo_path + '*.jpg')

img_base_names = []

for img in img_glob:
    # os.path.basename: strip the directory part, keep the file name
    img_base_names.append(os.path.basename(img))

img_pre_name = []

for img in img_base_names:
    # os.path.splitext: split the file name from its extension
    temp1, temp2 = os.path.splitext(img)
    img_pre_name.append(temp1)
    # NOTE(review): this print runs once per file, not once after the loop.
    print(f'imgpre:{len(img_pre_name)}')
# NOTE(review): the strings written below look like VOC XML with every
# angle-bracket tag stripped (e.g. ' folder\n' instead of
# '    <folder>folder</folder>\n'), and the per-object loop writes no class
# name or bounding-box fields at all -- presumably lost in a copy/paste.
# The output is NOT valid XML; left byte-identical here, needs the tags
# restored before use.
for i, img in enumerate(img_pre_name):
    if i % 100 == 0:
        print(i)
    with open(voc_xml + img + '.xml', 'w') as xml_files:
        image = Image.open(yolo_path + img + '.jpg')
        img_w, img_h = image.size
        xml_files.write('\n')
        xml_files.write(' folder\n')
        xml_files.write(f' {img}.jpg\n')
        xml_files.write(' \n')
        xml_files.write(' Unknown\n')
        xml_files.write(' \n')
        xml_files.write(' \n')
        xml_files.write(f' {img_w}\n')
        xml_files.write(f' {img_h}\n')
        xml_files.write(f' 3\n')
        xml_files.write(' \n')
        xml_files.write(' 0\n')
        with open(yolo_path + img + '.txt', 'r') as f:
            # splitlines() returns the file as a list of lines
            lines = f.read().splitlines()
        for each_line in lines:
            line = each_line.split(' ')
            xml_files.write(' \n')
        xml_files.write('')
+
+# # 划分数据集
+# xmlfilepath = your_voc_path + 'VOCdevkit/VOC2007/Annotations/'
+# txtsavepath = your_voc_path + 'VOCdevkit/VOC2007/ImgSets/Main/'
+# total_xml = os.listdir(xmlfilepath)
+#
+# num = len(total_xml)
+# list = range(num)
+# tv = int(num * trainval_percent)
+# tr = int(tv * train_percent)
+# trainval = random.sample(list, tv)
+# train = random.sample(trainval, tr)
+#
+# ftrainval = open(your_voc_path + 'VOCdevkit/VOC2007/ImageSets/Main/trainval.txt', 'w')
+# ftest = open(your_voc_path + 'VOCdevkit/VOC2007/ImageSets/Main/test.txt', 'w')
+# ftrain = open(your_voc_path + 'VOCdevkit/VOC2007/ImageSets/Main/train.txt', 'w')
+# fval = open(your_voc_path + 'VOCdevkit/VOC2007/ImageSets/Main/val.txt', 'w')
+#
+# for i in list:
+# name = total_xml[i][:-4] + '\n'
+# if i in trainval:
+# ftrainval.write(name)
+# if i in train:
+# ftrain.write(name)
+# else:
+# fval.write(name)
+# else:
+# ftest.write(name)
+#
+# ftrainval.close()
+# ftrain.close()
+# fval.close()
+# ftest.close()
+
# Data filtering: count the images per class (skipping a few excluded
# classes) so that classes above the image-count threshold can be selected.
print('根据阈值计算筛选后的类别')
path = your_voc_path + 'VOCdevkit/VOC2007/Annotations/'
listdir = os.listdir(path)
# One counter per class.  Sized from `labels` instead of the magic number 53
# so the two cannot drift apart.
count = [0] * len(labels)
# NOTE(review): only the FIRST <object> of each annotation is inspected, so
# an image counts toward a single class even if it contains several objects.
excluded = {'tutu', 'car', 'garbage', 'van car', 'Agricultural Tricycles', 'pickup'}
for file in listdir:
    with open(path + file, "r") as f:
        text = f.read()
        text = etree.fromstring(text)
        name = text.xpath('/annotation/object/name/text()')
        # Guard against annotation files with no <object> entries, which
        # previously raised IndexError on name[0].
        if name and name[0] not in excluded:
            count[labels.index(name[0])] += 1
new_labels = []  # labels of the classes that survive filtering
new_cls = []  # original class ids of the classes that survive filtering
+
+
+# 根据类别阈值筛选
def maxk(arraylist, k):
    """Return the indices of the k largest values in arraylist.

    The first k indices seed the running top-k; each later element that is
    strictly greater than the current minimum replaces it, so ties keep the
    earliest index.  The returned index list is in replacement order, not
    sorted by value.

    Args:
        arraylist: sequence of comparable values
        k: number of indices to return; clamped to len(arraylist), so
            k > len(arraylist) no longer raises IndexError

    Returns:
        list of int indices of the k largest elements
    """
    k = min(k, len(arraylist))  # clamp: previously k > len raised IndexError
    maxlist = [arraylist[i] for i in range(k)]  # current top-k values
    maxlist_id = list(range(k))  # indices matching maxlist positions
    for i in range(k, len(arraylist)):
        if arraylist[i] > min(maxlist):
            # Evict the current minimum and keep this larger element.
            mm = maxlist.index(min(maxlist))
            del maxlist[mm]
            del maxlist_id[mm]
            maxlist.append(arraylist[i])
            maxlist_id.append(i)
    return maxlist_id
+
+
+# kmax_list = maxk(count, threshold_cls)
+# for i in kmax_list:
+# new_cls.append(i)
+# new_labels.append(labels[i])
# Keep every class whose image count exceeds the threshold.
for cls_id, n_images in enumerate(count):
    if int(n_images) <= threshold:
        continue
    new_cls.append(cls_id)
    new_labels.append(labels[cls_id])
+
img_glob = glob.glob(yolo_path + '*.jpg')
# os.path.basename: strip the directory part, keep the file name
img_base_names = [os.path.basename(p) for p in img_glob]
print('开始筛选数据')
# os.path.splitext: split the file name from its extension; keep the stems
img_pre_name = [os.path.splitext(base)[0] for base in img_base_names]
count = 0
+
print('清空筛选后的文件夹')
# Empty the filtered output folders so stale files from a previous run
# cannot mix with the new output.
voc_filter_xml = your_voc_path + 'VOCdevkit_filtered/VOC2007/Annotations/'
for stale in os.listdir(voc_filter_xml):
    os.remove(voc_filter_xml + stale)
voc_filter_img = your_voc_path + 'VOCdevkit_filtered/VOC2007/JPEGImages/'
for stale in os.listdir(voc_filter_img):
    os.remove(voc_filter_img + stale)
# Write out the filtered data: keep only images whose first annotation line
# belongs to one of the selected classes, re-numbering files sequentially.
print('写入筛选后的数据')
for i, img in enumerate(img_pre_name):
    if i % 100 == 0:
        print(i)
    with open(yolo_path + img + '.txt', 'r') as f:
        # splitlines() returns the file as a list of lines
        lines = f.read().splitlines()
        # NOTE(review): only the FIRST line's class id decides whether the
        # whole image is kept, and only that one line is remapped below even
        # though all lines are written into the XML -- multi-object images
        # are handled inconsistently; verify against the labelling tool.
        line = lines[0].split(' ')
        if int(line[0]) in new_cls:
            # Build the filtered YOLO-format annotation with the remapped id.
            newcls = new_cls.index(int(line[0]))
            newanno_txt = line
            newanno_txt[0] = str(newcls)
            newtxt = ' '.join(newanno_txt)
            with open(yolo_filtered_path + str(count).zfill(6) + '.txt', 'w') as f:
                f.write(newtxt)
            shutil.copyfile(yolo_path + img + '.jpg', yolo_filtered_path + str(count).zfill(6) + '.jpg')
            # Build the filtered XML-format annotation.
            shutil.copyfile(yolo_path + img + '.jpg',
                            voc_filter_img + str(count).zfill(6) + '.jpg')
            # NOTE(review): as in the generation loop above, the XML written
            # here has had its angle-bracket tags stripped and contains no
            # class/bounding-box fields -- the output is not valid VOC XML.
            with open(voc_filter_xml + str(count).zfill(6) + '.xml', 'w') as xml_files:
                image = Image.open(yolo_path + img + '.jpg')
                img_w, img_h = image.size
                xml_files.write('\n')
                xml_files.write(' folder\n')
                xml_files.write(f' {img}.jpg\n')
                xml_files.write(' \n')
                xml_files.write(' Unknown\n')
                xml_files.write(' \n')
                xml_files.write(' \n')
                xml_files.write(f' {img_w}\n')
                xml_files.write(f' {img_h}\n')
                xml_files.write(f' 3\n')
                xml_files.write(' \n')
                xml_files.write(' 0\n')
                for each_line in lines:
                    line = each_line.split(' ')
                    xml_files.write(' \n')
                xml_files.write('')
            count += 1
+
xmlfilepath = your_voc_path + 'VOCdevkit_filtered/VOC2007/Annotations/'
# NOTE(review): 'ImgSets' here does not match the 'ImageSets' directory
# created above; txtsavepath and total_xml feed only the commented-out
# split below, so this currently has no effect -- fix before re-enabling.
txtsavepath = your_voc_path + 'VOCdevkit_filtered/VOC2007/ImgSets/Main/'
total_xml = os.listdir(xmlfilepath)

# # dataset split (disabled)
# num = len(total_xml)
# list = range(num)
# tv = int(num * trainval_percent)
# tr = int(tv * train_percent)
# trainval = random.sample(list, tv)
# train = random.sample(trainval, tr)
#
# ftrainval = open(your_voc_path + 'VOCdevkit_filtered/VOC2007/ImageSets/Main/trainval.txt', 'w')
# ftest = open(your_voc_path + 'VOCdevkit_filtered/VOC2007/ImageSets/Main/test.txt', 'w')
# ftrain = open(your_voc_path + 'VOCdevkit_filtered/VOC2007/ImageSets/Main/train.txt', 'w')
# fval = open(your_voc_path + 'VOCdevkit_filtered/VOC2007/ImageSets/Main/val.txt', 'w')
#
# for i in list:
#     name = total_xml[i][:-4] + '\n'
#     if i in trainval:
#         ftrainval.write(name)
#         if i in train:
#             ftrain.write(name)
#         else:
#             fval.write(name)
#     else:
#         ftest.write(name)
#
# ftrainval.close()
# ftrain.close()
# fval.close()
# ftest.close()
print('数据筛选完成')
diff --git a/two classification.py b/two classification.py
new file mode 100644
index 0000000..79b92ff
--- /dev/null
+++ b/two classification.py
@@ -0,0 +1,16 @@
+import os, random, shutil
def moveFile(fileDir, tarDir, rate=0.1):
    """Move a random sample of the files in fileDir into tarDir.

    Args:
        fileDir: source directory path
        tarDir: target directory path
        rate: fraction of the files to move (default 0.1, i.e. 10%);
            the sample size is int(n_files * rate)

    Returns:
        list of the file names that were moved
    """
    pathDir = os.listdir(fileDir)  # original file names
    picknumber = int(len(pathDir) * rate)  # how many files to take
    sample = random.sample(pathDir, picknumber)  # random sample of names
    print(sample)
    for name in sample:
        # os.path.join builds the paths portably; the old
        # `fileDir + name` / `tarDir + "\\" + name` concatenation produced
        # broken paths on non-Windows systems and relied on fileDir ending
        # with a separator.
        shutil.move(os.path.join(fileDir, name), os.path.join(tarDir, name))
    return sample
+
if __name__ == '__main__':
    fileDir = r"D:\snoring-dataset\Snoring Dataset\图像数据\mfcc-dataset\train\ss/"  # source image folder
    tarDir = r'D:\snoring-dataset\Snoring Dataset\图像数据\mfcc-dataset\test\ss/'  # destination folder
    moveFile(fileDir, tarDir)
\ No newline at end of file
diff --git a/xbbh.py b/xbbh.py
new file mode 100644
index 0000000..6371951
--- /dev/null
+++ b/xbbh.py
@@ -0,0 +1,46 @@
+import matplotlib.pyplot as plt
+import librosa.display
+import os
# Batch rename: give every .wav file in the folder a sequential
# zero-padded (6-digit) name.
vpath = 'D:\CloudMusic/ss'  # NOTE(review): '\C' is not a valid escape; prefer a raw string
mps_dir = os.listdir(vpath)
count = 0
for file in mps_dir:
    print(count)
    # Only the extension is needed; the unused `filename`/`i` locals of the
    # original loop are gone.
    filetype = os.path.splitext(file)[1]
    if filetype == '.wav':
        olddir = os.path.join(vpath, file)
        newdir = os.path.join(vpath, str(count).zfill(6) + '.wav')
        os.rename(olddir, newdir)
        count += 1
+
+
+# # 批量转图片-波形图
+# vpath='D:\CloudMusic/ss'
+# mps_dir=os.listdir(vpath)
+# count=0
+# for i,file in enumerate(mps_dir):
+# filename=os.path.splitext(file)[0]
+# filetype=os.path.splitext(file)[1]
+# audio_path=vpath+'/'+file
+# print(audio_path,file,filetype,filename)
+# if filetype=='.wav':
+# music,sr=librosa.load(audio_path)
+# plt.figure(figsize=(4,4))
+# librosa.display.waveplot(music,sr=sr)
+# plt.savefig(vpath+'/'+filename)
+# # # plt.show()
+
+# # 音乐文件载入
+# path='D:\CloudMusic'
+# filename='1.wav'
+# audio_path = path+'/'+filename
+# music, sr = librosa.load(audio_path)
+#
+# # 宽高比为14:5的图
+# plt.figure(figsize=(224, 224))
+# librosa.display.waveplot(music, sr=sr)
+# plt.savefig('D:\CloudMusic/1.jpg')
+# # 显示图
+# plt.show()
diff --git "a/\346\211\271\351\207\217\345\244\204\347\220\206\346\263\242\345\275\242\345\233\276.py" "b/\346\211\271\351\207\217\345\244\204\347\220\206\346\263\242\345\275\242\345\233\276.py"
new file mode 100644
index 0000000..3d06121
--- /dev/null
+++ "b/\346\211\271\351\207\217\345\244\204\347\220\206\346\263\242\345\275\242\345\233\276.py"
@@ -0,0 +1,40 @@
+import pyworld
+import librosa
+import librosa.display
+from IPython.display import Audio
+import numpy as np
+from matplotlib import pyplot as plt
+import math
+import os
+import matplotlib.pyplot as plt
# Plot style
# plt.style.use('seaborn')
# Folder holding the source .wav files
path = r'D:\snoring-dataset\Snoring Dataset\音频数据\0/'
# File list
waveforms = os.listdir(path)
pngpath = r'D:\snoring-dataset\Snoring Dataset\音频数据\波形图\no/'
names = []

count = 0
# Batch-process the wav files: one waveform PNG per input file.
for file in waveforms:
    file_name = os.path.splitext(file)[0]
    filename = path + file
    if count % 10 == 0:
        print(count)
    # names.append(filename)
    # Render the waveform
    x, fs = librosa.load(filename, sr=16000)  # librosa.load yields float32
    x = x.astype(np.double)  # convert to double
    fig = plt.figure(figsize=(16, 11), dpi=50)
    librosa.display.waveplot(x, sr=fs,)
    # Save the waveform image
    plt.savefig(pngpath + file_name + '.png')
    # plt.show()
    # Close the figure: without plt.close every iteration leaked an open
    # figure, so matplotlib warned and memory grew after a few dozen files.
    # (The unused file_type and pyworld fftlen computations are removed.)
    plt.close(fig)
    count += 1
# print(names)
\ No newline at end of file
diff --git "a/\346\211\271\351\207\217\347\224\237\346\210\220mfcc\345\233\276.py" "b/\346\211\271\351\207\217\347\224\237\346\210\220mfcc\345\233\276.py"
new file mode 100644
index 0000000..ff32e20
--- /dev/null
+++ "b/\346\211\271\351\207\217\347\224\237\346\210\220mfcc\345\233\276.py"
@@ -0,0 +1,44 @@
+import pyworld
+import librosa
+import librosa.display
+from IPython.display import Audio
+import numpy as np
+from matplotlib import pyplot as plt
+import math
+import os
+import matplotlib.pyplot as plt
# Plot style
# plt.style.use('seaborn')
# Folder holding the source .wav files
path = r'D:\snoring-dataset\Snoring Dataset\音频数据\1-snoring sounds/'
# File list
waveforms = os.listdir(path)
pngpath = r'D:\snoring-dataset\Snoring Dataset\音频数据\mfcc\ss/'
names = []

count = 0
# Batch-process the wav files: one log-mel spectrogram PNG per input file.
for file in waveforms:
    file_name = os.path.splitext(file)[0]
    filename = path + file
    print(count)
    # names.append(filename)
    # Mel spectrogram -> dB scale
    y, sr = librosa.load(filename, sr=16000)  # librosa.load yields float32
    melspec = librosa.feature.melspectrogram(y, sr, n_fft=1024, hop_length=512, n_mels=128)
    logmelspec = librosa.power_to_db(melspec)  # convert to log (dB) scale
    fig = plt.figure(figsize=(16, 11), dpi=50)
    librosa.display.specshow(logmelspec, sr=sr)
    # Save the spectrogram image
    plt.savefig(pngpath + file_name + '.png')
    # plt.show()
    # Close the figure: without plt.close every iteration leaked an open
    # figure.  A stray trailing `plt.figure()` left over from an earlier
    # experiment (it only created an empty, unused figure) is also removed.
    plt.close(fig)
    count += 1
# print(names)
+
diff --git "a/\346\211\271\351\207\217\347\224\237\346\210\220stft\345\243\260\350\260\261\345\233\276.py" "b/\346\211\271\351\207\217\347\224\237\346\210\220stft\345\243\260\350\260\261\345\233\276.py"
new file mode 100644
index 0000000..2f4f5a6
--- /dev/null
+++ "b/\346\211\271\351\207\217\347\224\237\346\210\220stft\345\243\260\350\260\261\345\233\276.py"
@@ -0,0 +1,41 @@
+import pyworld
+import librosa
+import librosa.display
+from IPython.display import Audio
+import numpy as np
+from matplotlib import pyplot as plt
+import math
+import os
+import matplotlib.pyplot as plt
# Plot style
# plt.style.use('seaborn')
# Folder holding the source .wav files
path = r'D:\snoring-dataset\Snoring Dataset\音频数据\1-snoring sounds/'
# File list
waveforms = os.listdir(path)
pngpath = r'D:\snoring-dataset\Snoring Dataset\音频数据\stft\ss/'
names = []

count = 0
# Batch-process the wav files: one log-magnitude STFT PNG per input file.
for file in waveforms:
    file_name = os.path.splitext(file)[0]
    filename = path + file
    print(file_name)
    # names.append(filename)
    x, fs = librosa.load(filename, sr=16000)  # librosa.load yields float32
    x = x.astype(np.double)  # convert to double for pyworld
    fftlen = pyworld.get_cheaptrick_fft_size(fs)  # auto-pick a suitable FFT size
    S = librosa.stft(x, n_fft=fftlen)
    # NOTE(review): np.log of a zero magnitude yields -inf plus a runtime
    # warning; consider np.log(np.abs(S) + eps) if that ever shows up.
    fig = plt.figure(figsize=(16, 11), dpi=50)
    librosa.display.specshow(np.log(np.abs(S)), sr=fs)
    # Save the spectrogram image
    plt.savefig(pngpath + file_name + '.png')
    # plt.close releases the figure; omitting it leaked one figure per file.
    plt.close(fig)
    count += 1
# print(names)
diff --git "a/\346\227\266\345\237\237\351\242\221\345\237\237\345\233\276.py" "b/\346\227\266\345\237\237\351\242\221\345\237\237\345\233\276.py"
new file mode 100644
index 0000000..f288082
--- /dev/null
+++ "b/\346\227\266\345\237\237\351\242\221\345\237\237\345\233\276.py"
@@ -0,0 +1,108 @@
+import wave
+import pyaudio
+import pylab
+import numpy as np
+import matplotlib.pyplot as plt
+
+
def get_framerate(wavefile):
    '''
    Return the frame rate (sample rate, Hz) of the given WAV file path.

    Fixes two defects in the original: the body opened the module-level
    global `wavfile` instead of the `wavefile` parameter, and the wave
    handle (plus an unused PyAudio object) was never closed.
    '''
    with wave.open(wavefile, "rb") as wf:
        return wf.getframerate()
+
+
def get_nframes(wavefile):
    '''
    Return the number of audio frames in the given WAV file path.

    Fixes two defects in the original: the body opened the module-level
    global `wavfile` instead of the `wavefile` parameter, and the wave
    handle (plus an unused PyAudio object) was never closed.
    '''
    with wave.open(wavefile, "rb") as wf:
        return wf.getnframes()
+
+
def get_wavedata(wavfile):
    '''
    Read a stereo WAV file and return its samples as a 2xN int16 array
    (row 0 = left channel, row 1 = right channel).

    The original also created a PyAudio object and an *output* audio
    stream that were never used or closed -- a resource leak that even
    required a sound device; both are removed.
    '''
    # 1. Read the raw frames (bytes); the with-statement closes the file.
    with wave.open(wavfile, "rb") as wf:
        nframes = wf.getnframes()
        str_data = wf.readframes(nframes)

    # 2. Reinterpret the bytes as 16-bit samples.  Stereo frames are
    #    interleaved L,R, so reshaping to (-1, 2) and transposing yields
    #    the (2, N) target array.
    wave_data = np.frombuffer(str_data, dtype=np.short)
    wave_data = wave_data.reshape(-1, 2).T
    return wave_data
+
+
def plot_timedomain(wavfile):
    '''
    Plot the left and right channels of a WAV file against time.
    '''
    samples = get_wavedata(wavfile)      # (2, N) channel data
    framerate = get_framerate(wavfile)   # sample rate (Hz)
    nframes = get_nframes(wavfile)       # number of frames

    # Time axis: one point per frame, in seconds.
    t = np.arange(0, nframes) * (1.0 / framerate)

    # Two stacked subplots, one per channel.
    pylab.figure(figsize=(40, 10))
    pylab.subplot(211)
    pylab.plot(t, samples[0])            # top: left channel
    pylab.subplot(212)
    pylab.plot(t, samples[1], c="g")     # bottom: right channel
    pylab.xlabel("time (seconds)")
    pylab.show()
    return None
+
+
def plot_freqdomain(start, fft_size, wavfile):
    '''
    Plot the magnitude spectrum of both channels of a WAV file.

    Args:
        start: index of the first sample to analyze
        fft_size: analysis window size in samples (also scales amplitudes)
        wavfile: path to the WAV file
    '''
    waveData = get_wavedata(wavfile)  # (2, N) channel samples
    framerate = get_framerate(wavfile)  # sample rate in Hz

    #### 1. FFT of the selected slice; rfft keeps the non-negative half.
    fft_y1 = np.fft.rfft(waveData[0][start:start + fft_size - 1]) / fft_size  # left
    fft_y2 = np.fft.rfft(waveData[1][start:start + fft_size - 1]) / fft_size  # right

    #### 2. Frequency axis: 0 Hz up to the Nyquist frequency (framerate/2).
    # fft_size // 2 -- np.linspace requires an integer sample count; the
    # original passed the float fft_size / 2, which raises TypeError on
    # Python 3 / modern NumPy.
    # NOTE(review): the slice above is fft_size - 1 samples long, so the
    # rfft length matches len(freqs) only for even fft_size.
    freqs = np.linspace(0, framerate / 2, fft_size // 2)

    #### 3. Two stacked subplots, one spectrum per channel.
    plt.figure(figsize=(20, 10))
    pylab.subplot(211)
    plt.plot(freqs, np.abs(fft_y1))
    pylab.xlabel("frequence(Hz)")
    pylab.subplot(212)
    plt.plot(freqs, np.abs(fft_y2), c='g')
    pylab.xlabel("frequence(Hz)")
    plt.show()
+
+
# Demo: time-domain plot of the whole file, then the spectrum of a
# 4000-sample window starting at sample 10000.
wavfile='D:\CloudMusic\ss/000005.wav'
plot_timedomain(wavfile=wavfile)
plot_freqdomain(10000,4000,wavfile)
diff --git "a/\346\227\266\351\225\277.py" "b/\346\227\266\351\225\277.py"
new file mode 100644
index 0000000..e79c899
--- /dev/null
+++ "b/\346\227\266\351\225\277.py"
@@ -0,0 +1,8 @@
+import contextlib
+import wave
# Print the duration of a single WAV file in seconds.
file_path = r"D:\snoring-dataset\Snoring Dataset\1_1.wav"
with contextlib.closing(wave.open(file_path, 'r')) as wav:
    frames, rate = wav.getnframes(), wav.getframerate()
wav_length = frames / float(rate)  # duration = frames / sample rate
print("音频长度:",wav_length,"秒")
diff --git "a/\346\227\266\351\242\221\350\260\261\357\274\214\350\257\255\350\260\261\345\233\276\357\274\214mel\350\257\255\350\260\261\345\200\222\350\260\261.py" "b/\346\227\266\351\242\221\350\260\261\357\274\214\350\257\255\350\260\261\345\233\276\357\274\214mel\350\257\255\350\260\261\345\200\222\350\260\261.py"
new file mode 100644
index 0000000..1b257a9
--- /dev/null
+++ "b/\346\227\266\351\242\221\350\260\261\357\274\214\350\257\255\350\260\261\345\233\276\357\274\214mel\350\257\255\350\260\261\345\200\222\350\260\261.py"
@@ -0,0 +1,76 @@
+import matplotlib
+import pyworld
+import librosa
+import librosa.display
+from IPython.display import Audio
+import numpy as np
+from matplotlib import pyplot as plt
+import math
# plt.style.use('seaborn-white')
# plt.style.use('seaborn')
# Waveform: load the audio
x, fs = librosa.load("D:\snoring-dataset\Snoring Dataset\音频数据/0-non-snoring sounds/000869.wav", sr=16000) # librosa.load outputs a float32 waveform
x = x.astype(np.double) # convert to double for pyworld

fftlen = pyworld.get_cheaptrick_fft_size(fs) # auto-compute a suitable fftlen
# Waveform plot (disabled experiment)
# plt.figure(figsize=(26,13),dpi=32)
# # plt.figure()
# librosa.display.waveplot(x, sr=fs,x_axis=None,)
# # plt.savefig('D:\snoring-dataset\Snoring Dataset/000000-0.png')
# plt.show()
# Audio(x, rate=fs)
# Spectrogram via plt.specgram (disabled experiment)
# plt.figure()
# plt.specgram(x,NFFT=fftlen, Fs=fs,noverlap=fftlen*1/4, window=np.hanning(fftlen))
# # plt.ylabel('Frequency')
# # plt.xlabel('Time(s)')
# # plt.title('specgram')
# plt.show()
# Power spectrogram (disabled experiment)
# D = librosa.amplitude_to_db(librosa.stft(x), ref=np.max)#20log|x|
# plt.figure()
# # librosa.display.specshow(D, sr=fs, hop_length=fftlen*1/4,y_axis='linear')
# librosa.display.specshow(D, sr=fs,hop_length=fftlen*1/4)
# # plt.colorbar(format='%+2.0f dB')
# # plt.title('Linear-frequency power spectrogram')
# plt.show()

# STFT spectrogram (disabled experiment)
# S = librosa.stft(x,n_fft=fftlen) # magnitudes
# plt.figure()
# # librosa.display.specshow(np.log(np.abs(S)), sr=fs,hop_length=fftlen/4)
# librosa.display.specshow(np.log(np.abs(S)), sr=fs)
# # plt.colorbar()
# # plt.title('STFT')
# plt.savefig('1')
# plt.show()


# Mel spectrogram (disabled experiment)
# melspec = librosa.feature.melspectrogram(x, sr=fs, n_fft=fftlen, n_mels=128) #(128,856)
# logmelspec = librosa.power_to_db(melspec)# (128,856)
# plt.figure()
# # librosa.display.specshow(logmelspec, sr=fs, x_axis='time', y_axis='mel')
# librosa.display.specshow(logmelspec, sr=fs)
# # plt.title('log melspectrogram')
# plt.show()

# MFCC / log-mel spectrogram of a second file
y, sr = librosa.load('D:\snoring-dataset\Snoring Dataset/1_0.wav', sr=16000)
# extract the mel spectrogram feature
# melspec = librosa.feature.melspectrogram(y, sr, n_fft=1024, hop_length=512, n_mels=128)
melspec = librosa.feature.melspectrogram(y, sr, n_fft=1024, hop_length=512, n_mels=128)
logmelspec = librosa.power_to_db(melspec)  # convert to log (dB) scale
# draw the mel spectrogram
plt.figure()
librosa.display.specshow(logmelspec, sr=sr)
# librosa.display.specshow(logmelspec, sr=sr, x_axis='time', y_axis='mel')
# plt.colorbar(format='%+2.0f dB') # colorbar on the right
# plt.title('Beat wavform')
plt.show()
+
+
+
+
+
diff --git "a/\346\263\242\345\275\242\345\233\276.py" "b/\346\263\242\345\275\242\345\233\276.py"
new file mode 100644
index 0000000..07d3778
--- /dev/null
+++ "b/\346\263\242\345\275\242\345\233\276.py"
@@ -0,0 +1,18 @@
+import pyworld
+import librosa
+import librosa.display
+from IPython.display import Audio
+import numpy as np
+from matplotlib import pyplot as plt
# plt.style.use('seaborn-white')
# plt.style.use('seaborn')
# Render one waveform plot and save it to disk.
audio, rate = librosa.load("D:\snoring-dataset\Snoring Dataset/1_0.wav", sr=16000)  # float32 waveform
audio = audio.astype(np.double)  # convert to double
fftlen = pyworld.get_cheaptrick_fft_size(rate)  # auto-compute a suitable fftlen
plt.figure(figsize=(16, 11), dpi=50)
librosa.display.waveplot(audio, sr=rate, x_axis=None)
plt.savefig('D:\snoring-dataset\Snoring Dataset/1_0-1.png')
plt.show()
\ No newline at end of file
diff --git "a/\347\224\273\345\233\276-\346\237\261\347\212\266\345\233\276.py" "b/\347\224\273\345\233\276-\346\237\261\347\212\266\345\233\276.py"
new file mode 100644
index 0000000..42469d2
--- /dev/null
+++ "b/\347\224\273\345\233\276-\346\237\261\347\212\266\345\233\276.py"
@@ -0,0 +1,19 @@
+import os
+import matplotlib.pyplot as plt
plt.style.use('seaborn')
# Bar chart of the dataset sizes.
# `categories` instead of `type`: the original name shadowed the builtin type().
categories = ['Snoring-kaggle', 'No-Snoring-kaggle', 'Snoring-ESC50', 'No-Snoring-ESC50']
num = [500, 500, 40, 40]
plt.figure()
x_ticks = range(len(categories))
# plt.bar(x_ticks, num, color=['b','r','g','y','c','m','y','k','c','g','b'])
plt.bar(x_ticks, num, color=['cornflowerblue', 'cornflowerblue', 'c', 'c'])
# x tick labels
plt.xticks(x_ticks, categories)
# title
plt.title("Snoring-Dataset")
# dashed grid
plt.grid(linestyle="--", alpha=0.7)
# plt.legend(loc='upper center', fontsize=15, ncol=2)
# save and show the figure
plt.savefig('数量柱状图.png')
plt.show()
\ No newline at end of file
diff --git "a/\351\245\274\345\233\276.py" "b/\351\245\274\345\233\276.py"
new file mode 100644
index 0000000..142b591
--- /dev/null
+++ "b/\351\245\274\345\233\276.py"
@@ -0,0 +1,19 @@
+import os
+import matplotlib.pyplot as plt
plt.style.use('seaborn')
# Pie chart of the dataset sizes.
# `categories` instead of `type`: the original name shadowed the builtin type().
categories = ['Snoring-kaggle', 'No-Snoring-kaggle', 'Snoring-ESC50', 'No-Snoring-ESC50']
num = [500, 500, 40, 40]
# create the canvas
plt.figure()

# draw the pie chart
plt.pie(num, labels=categories, colors=['dodgerblue', 'red', 'springgreen', 'y'], autopct="%1.2f%%")

# show the legend
plt.legend()

# equal aspect ratio so the pie is a circle
plt.axis('equal')
plt.savefig('饼图.png')
# show the figure
plt.show()