Commit 4182138 (initial commit, 0 parents)
Showing 20 changed files with 6,999 additions and 0 deletions.
@@ -0,0 +1 @@
# fast_word_emotion_analysis
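
A one-line summary inferred from the two scripts in this commit: one script captures microphone audio, transcribes it with Whisper (vumichien/whisper-small-ja), and appends each utterance to transcriptions.txt; a second script watches that file, scores the eight WRIME emotions with Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime, draws the scores on a polar radar chart, and sends a color code for the strongest emotion to an ESP32 over serial.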
@@ -0,0 +1,103 @@
import queue
import threading

import matplotlib.pyplot as plt
import numpy as np
import sounddevice as sd
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# SETTINGS
BLOCKSIZE = 24678 // 5  # samples per chunk (~0.31 s at 16 kHz)
SILENCE_THRESHOLD = 700  # int16 amplitude below which a sample counts as silence
MIN_AUDIO_LENGTH = 8000  # shortest buffer (0.5 s) that will be transcribed
SILENCE_RATIO = 300  # samples above the threshold needed to call a chunk speech
SAVE_PATH = "transcriptions.txt"

# Initialize Whisper model and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "vumichien/whisper-small-ja"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
use_fp16 = device.type == "cuda"  # half precision is not supported on CPU
if use_fp16:
    model = model.half()
forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")

global_ndarray = None  # running buffer of speech chunks awaiting transcription
audio_queue = queue.Queue()

running = True


def audio_capture_thread():
    # Read fixed-size chunks from the default microphone and queue them.
    with sd.InputStream(
        samplerate=16000, channels=1, dtype="int16", blocksize=BLOCKSIZE
    ) as stream:
        while running:
            indata, overflowed = stream.read(BLOCKSIZE)
            audio_queue.put(indata)

    audio_queue.put(None)  # Sentinel value to indicate end of stream


def transcription_and_plotting():
    plt.ion()
    fig, ax = plt.subplots()
    (line,) = ax.plot(np.random.randn(BLOCKSIZE))
    ax.set_ylim([-(2**15), 2**15 - 1])  # full int16 range
    ax.set_xlim(0, BLOCKSIZE)

    global global_ndarray

    while running:
        indata = audio_queue.get()
        if indata is None:  # If the end-of-stream sentinel is found, break the loop
            break

        indata_flattened = abs(indata.flatten())

        # Live waveform of the current chunk
        line.set_ydata(indata.flatten())
        plt.draw()
        plt.pause(0.001)

        # A chunk counts as speech when enough samples exceed the silence threshold
        is_significant_audio = (
            np.asarray(np.where(indata_flattened > SILENCE_THRESHOLD)).size
            >= SILENCE_RATIO
        )

        if is_significant_audio:
            if global_ndarray is not None:
                global_ndarray = np.concatenate((global_ndarray, indata), dtype="int16")
            else:
                global_ndarray = indata
        elif global_ndarray is not None:
            # Silence after speech: transcribe the buffered utterance
            if len(global_ndarray) < MIN_AUDIO_LENGTH:
                continue  # too short; keep buffering
            indata_transformed = global_ndarray.flatten().astype(np.float32) / 32768.0
            global_ndarray = None
            input_data = processor(
                indata_transformed, sampling_rate=16000, return_tensors="pt"
            ).input_features
            if use_fp16:
                input_data = input_data.half()
            predicted_ids = model.generate(
                input_data.to(device), forced_decoder_ids=forced_decoder_ids
            )

            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            print(f"Transcription: {transcription}")

            # Append the utterance to the shared transcript and flush so the
            # emotion-analysis script sees it immediately
            with open(SAVE_PATH, "a", encoding="utf-8") as file:
                file.write(transcription + "\n")
                file.flush()


if __name__ == "__main__":
    capture_thread = threading.Thread(target=audio_capture_thread)
    capture_thread.start()

    try:
        transcription_and_plotting()
    except KeyboardInterrupt:
        print("\nInterrupted by user")
        running = False
        plt.close()

    capture_thread.join()
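
For intuition about the SETTINGS block above, a quick sanity check of the implied timing (plain arithmetic derived from the constants in the script, not part of the commit): at 16 kHz each chunk is 24678 // 5 = 4935 samples, roughly 0.31 s of audio, and a chunk is treated as speech when at least 300 of those samples exceed the int16 amplitude 700.

    # Sanity-check the timing implied by the SETTINGS above (values copied from the script).
    SAMPLE_RATE = 16000
    BLOCKSIZE = 24678 // 5      # 4935 samples
    MIN_AUDIO_LENGTH = 8000
    SILENCE_RATIO = 300

    print(f"chunk duration: {BLOCKSIZE / SAMPLE_RATE:.3f} s")          # ~0.308 s
    print(f"min utterance:  {MIN_AUDIO_LENGTH / SAMPLE_RATE:.3f} s")   # 0.5 s
    print(f"speech gate:    {SILENCE_RATIO / BLOCKSIZE:.1%} of a chunk")  # ~6.1%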
@@ -0,0 +1,89 @@
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import serial  # pyserial; required for the ESP32 link below
import torch
from matplotlib.animation import FuncAnimation
from transformers import AutoTokenizer, AutoModelForSequenceClassification

plt.rcParams["font.family"] = "Meiryo"  # Japanese-capable font for the labels

SAVE_PATH = "transcriptions.txt"

tokenizer = AutoTokenizer.from_pretrained(
    "Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime"
)
model = AutoModelForSequenceClassification.from_pretrained(
    "Mizuiro-sakura/luke-japanese-large-sentiment-analysis-wrime"
)

# The eight WRIME emotions; the first label is repeated so the radar polygon closes
emotions = ["喜び", "悲しみ", "期待", "驚き", "怒り", "恐れ", "嫌悪", "信頼", "喜び"]


def get_emotion_probs(text):
    token = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding="max_length"
    )
    output = model(**token)
    # Min-max normalize the logits to [0, 1] (relative scores, not a softmax)
    normalized_logits = (output.logits - torch.min(output.logits)) / (
        torch.max(output.logits) - torch.min(output.logits)
    )
    probs = normalized_logits.squeeze().tolist()
    probs.append(probs[0])  # repeat the first score to close the polygon
    return probs


fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
ax.set_ylim(0, 1)
theta = np.linspace(0, 2 * np.pi, len(emotions), endpoint=True)
(line,) = ax.plot(theta, [0] * len(emotions))
ax.set_xticks(theta)
ax.set_xticklabels(emotions)

last_read_line = 0
last_mtime = os.path.getmtime(SAVE_PATH)  # modification time of the file at the last check

serialPort = "/dev/ttyUSB0"
baudRate = 115200
ser = serial.Serial(serialPort, baudRate, timeout=1)
time.sleep(2)  # give the ESP32 time to reset after the port opens

# Mapping from emotion to LED color code
emotion_colors = {
    "喜び": "Y",  # yellow
    "悲しみ": "B",  # blue
    "期待": "G",  # green
    # Define codes for the remaining emotions the same way
}


def send_color_to_esp32(color_code):
    ser.write(color_code.encode())


def update(frame):
    global last_read_line, last_mtime

    current_mtime = os.path.getmtime(SAVE_PATH)

    if current_mtime != last_mtime:
        with open(SAVE_PATH, "r", encoding="utf-8") as file:
            lines = file.readlines()
            if last_read_line < len(lines):
                text = lines[-1].strip()  # score only the newest utterance
                emotion_probs = get_emotion_probs(text)
                line.set_ydata(emotion_probs)
                last_read_line = len(lines)
                last_mtime = current_mtime

                highest_emotion = emotions[emotion_probs.index(max(emotion_probs))]
                color_code = emotion_colors.get(highest_emotion, "R")  # "R" is the fallback
                send_color_to_esp32(color_code)  # send the color to the ESP32

    return (line,)


MAX_FRAMES = 100

ani = FuncAnimation(fig, update, repeat=True, blit=True, save_count=MAX_FRAMES)

plt.show()
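
The commit assigns color codes to only three emotions and leaves a comment to define the rest the same way. One possible completion, meant to run in the same session as the script above (the five added code letters are assumptions, not part of the commit, and the ESP32 firmware would have to agree on them):

    # Hypothetical full mapping; only "Y", "B", "G" and the "R" fallback appear in the commit.
    emotion_colors = {
        "喜び": "Y",    # joy -> yellow
        "悲しみ": "B",  # sadness -> blue
        "期待": "G",    # anticipation -> green
        "驚き": "C",    # surprise -> cyan (assumed)
        "怒り": "R",    # anger -> red (assumed)
        "恐れ": "P",    # fear -> purple (assumed)
        "嫌悪": "M",    # disgust -> magenta (assumed)
        "信頼": "W",    # trust -> white (assumed)
    }

    # Smoke test: cycle through every code once and watch the LED change color.
    for emotion, code in emotion_colors.items():
        send_color_to_esp32(code)
        time.sleep(0.5)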