0 parents · commit 1e78a5b
Showing 9 changed files with 2,213 additions and 0 deletions.
.gitignore
@@ -0,0 +1 @@
/web_api/target
LICENSE
@@ -0,0 +1,19 @@
Copyright (c) 2024 Denis Avvakumov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
README.md
@@ -0,0 +1,47 @@
# Intro
A friend of mine asked me to demonstrate how to load a model trained in Python into a Rust service. In response, this repository showcases the entire process of training a machine learning model to distinguish between various text encodings, reaching around 98.5% validation accuracy on data sourced from the English Wiktionary.
The trained model is then integrated into a Rust microservice built on the ntex-rs framework. The implementation is kept to minimal dependencies, ensuring a lightweight and efficient service.

Supported encodings (see the example after this list):

1. Plain text
2. Rot13
3. Caesar
4. Base85
5. Base64
6. Base58
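
As a quick illustration (not code from this repository), here is what each supported encoding produces for one input word; the Caesar shift of 3 and the `base58` package match what the training notebook uses:

```python
import codecs
from base64 import b64encode, b85encode
import base58  # third-party package, also used by the notebook

word = "HELLO"

samples = {
    "text": word,
    "rot13": codecs.encode(word, "rot13"),
    # Caesar cipher with the notebook's default shift of 3
    "caesar": "".join(
        chr((ord(c) - 65 + 3) % 26 + 65) if c.isupper() else c for c in word
    ),
    "base85": b85encode(word.encode()).decode(),
    "base64": b64encode(word.encode()).decode(),
    "base58": base58.b58encode(word.encode()).decode(),
}

for method, encoded in samples.items():
    print(f"{method:>6}: {encoded}")
```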

## Dependencies

- Python 3.10+
- Rust 1.75+

## Train Model
1. Download the [English Wiktionary dump](https://dumps.wikimedia.org/enwiktionary/).
2. Open `test.ipynb` and set the variable `wiktionary_dump_filepath` to point to the downloaded dump file.
3. Execute the first cell of the notebook to extract the word list.
4. Execute the second cell to train the model (ensure that all dependencies associated with TensorFlow and Keras are properly installed).
5. To evaluate the model, run the third cell; a standalone check is sketched below.
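
After training finishes, a minimal standalone sanity check might look like the sketch below. It assumes `detector.keras` was written by the checkpoint callback in the second cell, and it mirrors the notebook's character-level preprocessing:

```python
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Same character table as the notebook: ASCII codes 0-127,
# shifted by one so index 0 stays free for padding.
max_length = 128
char_to_index = {chr(i): i + 1 for i in range(128)}

model = load_model("detector.keras")

text = "HELLO WORLD"
seq = [[char_to_index.get(c, 0) for c in text]]
padded = pad_sequences(seq, maxlen=max_length, padding="post")

probs = model.predict(padded)[0]
encodings = ["text", "rot13", "caesar", "base85", "base64", "base58"]
print({m: f"{p:.2%}" for m, p in zip(encodings, probs)})
```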

## Usage of web_api

Run the following in the `web_api` directory (the service listens on `127.0.0.1:3000`, as used by the request below):
```
cargo run --release
```

Then send a request from another terminal:

```
curl -X POST -H "Content-Type: application/json" -d '{"language":"English","data":"HELLO WORLD"}' http://127.0.0.1:3000/predict
```

The prediction is returned in the following format (one percentage per encoding):

```json
{"text":"99.51","rot13":"0.00","caesar":"0.49","base85":"0.00","base64":"0.00","base58":"0.00"}
```
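
Equivalently, here is a small Python client; this is a sketch that assumes the service is running locally as started above, and the `predict` helper is hypothetical, not part of the repository:

```python
import requests

def predict(text, language="English", url="http://127.0.0.1:3000/predict"):
    """Hypothetical helper around the /predict endpoint shown above."""
    response = requests.post(url, json={"language": language, "data": text})
    response.raise_for_status()
    return response.json()  # e.g. {"text": "99.51", "rot13": "0.00", ...}

scores = predict("HELLO WORLD")
# Pick the encoding with the highest predicted percentage.
best = max(scores, key=lambda k: float(scores[k]))
print(best, scores[best])
```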

## License
This project is licensed under the [MIT license](LICENSE).
test.ipynb
@@ -0,0 +1,303 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xml.etree.ElementTree as ET\n",
    "\n",
    "\n",
    "def is_acceptable_character(c):\n",
    "    # Check for basic Latin letters and digits\n",
    "    if c.isascii() and (c.isalpha() or c.isdigit()):\n",
    "        return True\n",
    "    return False\n",
    "\n",
    "\n",
    "def filter_words(words):\n",
    "    filtered_words = set()\n",
    "    for word in words:\n",
    "        # Reconstruct each word using only acceptable characters\n",
    "        filtered_word = \"\".join(c for c in word if is_acceptable_character(c))\n",
    "        # Keep the word only if something remains after filtering\n",
    "        if filtered_word:\n",
    "            filtered_words.add(filtered_word)\n",
    "    return filtered_words\n",
    "\n",
    "\n",
    "def extract_words_by_language(input_filename, output_filename, language=\"English\"):\n",
    "    words_set = set()\n",
    "\n",
    "    # Define the language marker we're looking for in the content\n",
    "    language_marker = f\"=={language}==\"\n",
    "\n",
    "    # For storing the title temporarily\n",
    "    current_title = None\n",
    "\n",
    "    # Create an iterable parsing of the XML file\n",
    "    context = ET.iterparse(input_filename, events=(\"start\", \"end\"))\n",
    "    context = iter(context)\n",
    "\n",
    "    # Get the root element\n",
    "    event, root = next(context)\n",
    "\n",
    "    for event, elem in context:\n",
    "        if event == \"end\" and elem.tag.endswith(\"title\"):\n",
    "            current_title = elem.text\n",
    "        elif event == \"end\" and elem.tag.endswith(\"text\"):\n",
    "            # Check that elem.text is not None before attempting to search it\n",
    "            if elem.text and language_marker in elem.text and current_title:\n",
    "                # Split the title into individual words on spaces\n",
    "                for word in current_title.split():\n",
    "                    # Add each word to the set\n",
    "                    words_set.add(word)\n",
    "            current_title = None\n",
    "\n",
    "        # Clear processed elements to save memory\n",
    "        root.clear()\n",
    "\n",
    "    words_set = filter_words(words_set)\n",
    "    # Write the filtered and individualized words to a file\n",
    "    with open(output_filename, \"w\", encoding=\"utf-8\") as f:\n",
    "        for word in sorted(words_set):  # Sorting for easier readability\n",
    "            f.write(f\"{word}\\n\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    wiktionary_dump_filepath = \"enwiktionary-20240201-pages-articles.xml\"\n",
    "    output_filepath = \"english_wiktionary_words.txt\"\n",
    "    extract_words_by_language(wiktionary_dump_filepath, output_filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from base64 import b64encode, b85encode\n",
    "import base58\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Embedding, Dense, LSTM, Dropout, Bidirectional\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Encoding functions\n",
    "def encode_text(text, method=\"rot13\", shift=3):\n",
    "    if method == \"text\":\n",
    "        return text\n",
    "    elif method == \"rot13\":\n",
    "        return text.translate(\n",
    "            str.maketrans(\n",
    "                \"ABCDEFGHIJKLMabcdefghijklmNOPQRSTUVWXYZnopqrstuvwxyz\",\n",
    "                \"NOPQRSTUVWXYZnopqrstuvwxyzABCDEFGHIJKLMabcdefghijklm\",\n",
    "            )\n",
    "        )\n",
    "    elif method == \"caesar\":\n",
    "        return \"\".join(\n",
    "            (\n",
    "                chr((ord(char) - 65 + shift) % 26 + 65)\n",
    "                if char.isupper()\n",
    "                else chr((ord(char) - 97 + shift) % 26 + 97) if char.islower() else char\n",
    "            )\n",
    "            for char in text\n",
    "        )\n",
    "    elif method == \"base85\":\n",
    "        return b85encode(text.encode()).decode()\n",
    "    elif method == \"base64\":\n",
    "        return b64encode(text.encode()).decode()\n",
    "    elif method == \"base58\":\n",
    "        return base58.b58encode(text.encode()).decode()\n",
    "\n",
    "\n",
    "def load_and_preprocess(file_path, max_lines=160000):\n",
    "    data, labels = [], []\n",
    "    lines = open(file_path, \"r\").read().split(\"\\n\")\n",
    "    random.shuffle(lines)\n",
    "\n",
    "    for j, line in enumerate(lines):\n",
    "        if j >= max_lines:\n",
    "            break\n",
    "\n",
    "        sentence = line.strip()\n",
    "        sentence_upper = sentence.upper()\n",
    "\n",
    "        encodings = [\"text\", \"rot13\", \"caesar\", \"base85\", \"base64\", \"base58\"]\n",
    "        orig_methods = encodings[:]\n",
    "        random.shuffle(encodings)\n",
    "\n",
    "        # Each word contributes an original-case and an upper-case sample per encoding\n",
    "        for method in encodings:\n",
    "            data.append(encode_text(sentence, method))\n",
    "            data.append(encode_text(sentence_upper, method))\n",
    "            labels.append(orig_methods.index(method))\n",
    "            labels.append(orig_methods.index(method))\n",
    "\n",
    "    return data, to_categorical(labels, num_classes=6)\n",
    "\n",
    "\n",
    "# Load data\n",
    "file_path = \"english_wiktionary_words.txt\"\n",
    "data, labels = load_and_preprocess(file_path)\n",
    "\n",
    "# Character-level tokenization and sequencing\n",
    "max_length = 128\n",
    "chars = [chr(i) for i in range(128)]\n",
    "char_to_index = {c: i + 1 for i, c in enumerate(chars)}\n",
    "sequences = [[char_to_index.get(char, 0) for char in doc] for doc in data]\n",
    "padded = pad_sequences(sequences, maxlen=max_length, padding=\"post\")\n",
    "vocab_size = len(char_to_index) + 1\n",
    "\n",
    "# Model definition\n",
    "model = Sequential(\n",
    "    [\n",
    "        Embedding(vocab_size, 256, name=\"predict\", input_length=max_length),\n",
    "        Bidirectional(LSTM(128)),\n",
    "        Dense(256, activation=\"relu\"),\n",
    "        Dropout(0.5),\n",
    "        Dense(6, name=\"predict_output\", activation=\"softmax\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "model.compile(\n",
    "    loss=\"categorical_crossentropy\",\n",
    "    optimizer=Adam(learning_rate=0.0004),\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "model.summary()\n",
    "\n",
    "# Train model on an 80/20 train/validation split\n",
    "X_train, X_test = padded[: int(len(padded) * 0.8)], padded[int(len(padded) * 0.8) :]\n",
    "y_train, y_test = labels[: int(len(labels) * 0.8)], labels[int(len(labels) * 0.8) :]\n",
    "\n",
    "mcp_save = ModelCheckpoint(\n",
    "    \"detector.keras\", save_best_only=True, monitor=\"val_loss\", mode=\"min\"\n",
    ")\n",
    "# early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)\n",
    "model.fit(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    epochs=5,\n",
    "    batch_size=512,\n",
    "    validation_data=(X_test, y_test),\n",
    "    callbacks=[mcp_save],\n",
    ")\n",
    "model.save(\"detector\", save_format=\"tf\")\n",
    "\n",
    "# Plot training and validation loss\n",
    "history = model.history.history\n",
    "plt.plot(history[\"loss\"], \"g\", label=\"Training loss\")\n",
    "plt.plot(history[\"val_loss\"], \"r\", label=\"Validation loss\")\n",
    "plt.title(\"Training and Validation Loss\")\n",
    "plt.xlabel(\"Epochs\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tensorflow.keras.models import load_model\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from base64 import b64encode, b85encode\n",
    "import base58\n",
    "\n",
    "\n",
    "# Same encoding functions as in the training cell\n",
    "def encode_text(text, method=\"rot13\", shift=3):\n",
    "    if method == \"text\":\n",
    "        return text\n",
    "    elif method == \"rot13\":\n",
    "        return text.translate(\n",
    "            str.maketrans(\n",
    "                \"ABCDEFGHIJKLMabcdefghijklmNOPQRSTUVWXYZnopqrstuvwxyz\",\n",
    "                \"NOPQRSTUVWXYZnopqrstuvwxyzABCDEFGHIJKLMabcdefghijklm\",\n",
    "            )\n",
    "        )\n",
    "    elif method == \"caesar\":\n",
    "        return \"\".join(\n",
    "            (\n",
    "                chr((ord(char) - 65 + shift) % 26 + 65)\n",
    "                if char.isupper()\n",
    "                else chr((ord(char) - 97 + shift) % 26 + 97) if char.islower() else char\n",
    "            )\n",
    "            for char in text\n",
    "        )\n",
    "    elif method == \"base85\":\n",
    "        return b85encode(text.encode()).decode()\n",
    "    elif method == \"base64\":\n",
    "        return b64encode(text.encode()).decode()\n",
    "    elif method == \"base58\":\n",
    "        return base58.b58encode(text.encode()).decode()\n",
    "\n",
    "\n",
    "max_length = 128\n",
    "chars = [chr(i) for i in range(128)]\n",
    "char_to_index = {c: i + 1 for i, c in enumerate(chars)}\n",
    "\n",
    "\n",
    "# Preprocess the input text in the same way as the training data\n",
    "def preprocess_input_text(text):\n",
    "    encoded_texts = [\n",
    "        encode_text(text, method)\n",
    "        for method in [\"text\", \"rot13\", \"caesar\", \"base85\", \"base64\", \"base58\"]\n",
    "    ]\n",
    "    sequences = [[char_to_index.get(char, 0) for char in doc] for doc in encoded_texts]\n",
    "    padded_seq = pad_sequences(sequences, maxlen=max_length, padding=\"post\")\n",
    "    return padded_seq\n",
    "\n",
    "\n",
    "# Encode the input with every method and predict the encoding of each variant\n",
    "def predict(model_path, input_text):\n",
    "    model = load_model(model_path)\n",
    "    preprocessed_text = preprocess_input_text(input_text)\n",
    "    predictions = model.predict(preprocessed_text)\n",
    "    encodings = [\"text\", \"rot13\", \"caesar\", \"base85\", \"base64\", \"base58\"]\n",
    "    for method, prediction in zip(encodings, predictions):\n",
    "        print(\"Encoding:\", method)\n",
    "        print(\"Predicted encoding:\", encodings[np.argmax(prediction)])\n",
    "        print(\n",
    "            \"Predicted encoding percentages:\", [\"{:.2%}\".format(p) for p in prediction]\n",
    "        )\n",
    "\n",
    "\n",
    "# Example usage\n",
    "model_path = \"detector.keras\"\n",
    "input_text = \"hello there\"\n",
    "predict(model_path, input_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}