Commit

Add files via upload

kasi-x authored Apr 26, 2024
0 parents commit 4182138
Showing 20 changed files with 6,999 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
# fast_word_emotion_analysis
103 changes: 103 additions & 0 deletions dev/a.py
@@ -0,0 +1,103 @@
import queue
import threading

import matplotlib.pyplot as plt
import numpy as np
import sounddevice as sd
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# SETTINGS
BLOCKSIZE = 24678 // 5  # samples per captured block (~0.3 s at 16 kHz)
SILENCE_THRESHOLD = 700  # int16 amplitude below which a sample counts as silence
MIN_AUDIO_LENGTH = 8000  # shortest buffer worth transcribing (0.5 s at 16 kHz)
SILENCE_RATIO = 300  # a block counts as speech if at least this many samples are loud
SAVE_PATH = "transcriptions.txt"
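# dev/b.py polls this file and animates an emotion radar from each new line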

# Initialize Whisper model and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "vumichien/whisper-small-ja"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
use_fp16 = device.type == "cuda"  # float16 inference is only reliable on GPU
if use_fp16:
    model = model.half()
# Force Japanese transcription regardless of Whisper's language detection
forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")

global_ndarray = None  # buffer accumulating the current utterance's samples
audio_queue = queue.Queue()  # hands blocks from the capture thread to the main thread

running = True  # cleared on Ctrl-C so both loops stop


def audio_capture_thread():
    # Producer: read fixed-size blocks from the default microphone and queue them
    with sd.InputStream(
        samplerate=16000, channels=1, dtype="int16", blocksize=BLOCKSIZE
    ) as stream:
        while running:
            # read() blocks until BLOCKSIZE samples are available; the second
            # return value (overflow flag) is ignored here
            indata, status = stream.read(BLOCKSIZE)
            audio_queue.put(indata)

    audio_queue.put(None)  # Sentinel value to indicate end of stream


def transcription_and_plotting():
    # Consumer: live waveform display plus Whisper transcription of buffered speech
    plt.ion()
    fig, ax = plt.subplots()
    (line,) = ax.plot(np.random.randn(BLOCKSIZE))
    ax.set_ylim([-(2**15), 2**15 - 1])  # full int16 amplitude range
    ax.set_xlim(0, BLOCKSIZE)

global global_ndarray

    while running:
        indata = audio_queue.get()
        if indata is None:  # If end of stream sentinel is found, break the loop
            break

        indata_flattened = abs(indata.flatten())

        # Update the waveform display
        line.set_ydata(indata.flatten())
        plt.draw()
        plt.pause(0.001)

        # Crude voice-activity check: the block counts as speech when at least
        # SILENCE_RATIO samples exceed SILENCE_THRESHOLD in amplitude
        is_significant_audio = (
            np.asarray(np.where(indata_flattened > SILENCE_THRESHOLD)).size >= SILENCE_RATIO
        )

        if is_significant_audio:
            # Speech: keep accumulating samples in the buffer
            if global_ndarray is not None:
                global_ndarray = np.concatenate((global_ndarray, indata), dtype="int16")
            else:
                global_ndarray = indata
        elif global_ndarray is not None:
            # Silence after speech: transcribe the buffered utterance
            if len(global_ndarray) < MIN_AUDIO_LENGTH:
                continue
            # Convert int16 samples to float32 in [-1, 1] for Whisper
            indata_transformed = global_ndarray.flatten().astype(np.float32) / 32768.0
            global_ndarray = None
            input_data = processor(
                indata_transformed, sampling_rate=16000, return_tensors="pt"
            ).input_features
            if use_fp16:
                input_data = input_data.half()
            predicted_ids = model.generate(
                input_data.to(device), forced_decoder_ids=forced_decoder_ids
            )

            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            print(f"Transcription: {transcription}")

            with open(SAVE_PATH, "a", encoding="utf-8") as file:
                file.write(transcription + "\n")
                file.flush()


if __name__ == "__main__":
capture_thread = threading.Thread(target=audio_capture_thread)
capture_thread.start()

try:
transcription_and_plotting()
except KeyboardInterrupt:
print("\nInterrupted by user")
running = False
plt.close()

capture_thread.join()
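
# Ctrl-C stops the consumer loop and clears `running`, so the capture thread
# drains out and join() returns cleanly.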
89 changes: 89 additions & 0 deletions dev/b.py
@@ -0,0 +1,89 @@
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import serial  # pyserial; used for the ESP32 link below
import torch
from matplotlib.animation import FuncAnimation
from transformers import AutoModelForSequenceClassification, AutoTokenizer

plt.rcParams["font.family"] = "Meiryo"  # Japanese-capable font for the emotion labels

SAVE_PATH = "transcriptions.txt"

tokenizer = AutoTokenizer.from_pretrained(
"Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime"
)
model = AutoModelForSequenceClassification.from_pretrained(
"Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime"
)

# The eight WRIME emotions (joy, sadness, anticipation, surprise, anger, fear,
# disgust, trust); the first label is repeated so the radar polygon closes
emotions = ["喜び", "悲しみ", "期待", "驚き", "怒り", "恐れ", "嫌悪", "信頼", "喜び"]


def get_emotion_probs(text):
    token = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding="max_length"
    )
    with torch.no_grad():  # inference only; skip building the autograd graph
        output = model(**token)
    # Min-max normalize the logits into [0, 1] for plotting; this is not a
    # softmax, so the values are relative intensities rather than probabilities
    normalized_logits = (output.logits - torch.min(output.logits)) / (
        torch.max(output.logits) - torch.min(output.logits)
    )
    probs = normalized_logits.squeeze().tolist()
    probs.append(probs[0])  # repeat the first value to close the radar polygon
    return probs
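
# Usage sketch (hypothetical input): get_emotion_probs("今日は楽しかった")
# returns nine values in [0, 1], one per entry in `emotions`, with the first
# value repeated at the end so the radar polygon closes.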


fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
ax.set_ylim(0, 1)
theta = np.linspace(0, 2 * np.pi, len(emotions), endpoint=True)
(line,) = ax.plot(theta, [0] * len(emotions))
ax.set_xticks(theta)
ax.set_xticklabels(emotions)

last_read_line = 0
last_mtime = os.path.getmtime(SAVE_PATH)  # modification time when the file was last checked

serialPort = "/dev/ttyUSB0"
baudRate = 115200
ser = serial.Serial(serialPort, baudRate, timeout=1)
time.sleep(2)  # give the ESP32 time to reset after the serial port opens

# Mapping from emotion to the one-character color code sent to the ESP32
emotion_colors = {
    "喜び": "Y",  # yellow
    "悲しみ": "B",  # blue
    "期待": "G",  # green
    # Define codes for the remaining emotions in the same way
}
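
# The single-letter codes are assumed to match whatever the ESP32 firmware
# parses; "R" serves as the fallback for unmapped emotions (see update()).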


def send_color_to_esp32(color_code):
ser.write(color_code.encode())


def update(frame):
    global last_read_line, last_mtime

    current_mtime = os.path.getmtime(SAVE_PATH)

    # Re-read the transcript only when the file has actually changed
    if current_mtime != last_mtime:
        with open(SAVE_PATH, "r", encoding="utf-8") as file:
            lines = file.readlines()
            if last_read_line < len(lines):
                # Only the newest line is analyzed
                text = lines[-1].strip()
                emotion_probs = get_emotion_probs(text)
                line.set_ydata(emotion_probs)
                last_read_line = len(lines)
                last_mtime = current_mtime

                highest_emotion = emotions[emotion_probs.index(max(emotion_probs))]
                color_code = emotion_colors.get(highest_emotion, "R")
                send_color_to_esp32(color_code)  # send the color to the ESP32

    return (line,)


MAX_FRAMES = 100  # bound the frame cache FuncAnimation keeps for saving

ani = FuncAnimation(fig, update, repeat=True, blit=True, save_count=MAX_FRAMES)

plt.show()
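
# Assumed workflow: run dev/a.py first so transcriptions.txt exists and keeps
# growing, then start this script to watch the radar and LED color update.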