Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions btcopilot/pro/models/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from btcopilot.schema import DiagramData, PDP, from_dict
from btcopilot.extensions import db
from btcopilot.modelmixin import ModelMixin
from btcopilot.pro.safe_pickle import safe_loads_diagram


# TODO: Remove once pro version adoption gets past 2.1.11
Expand Down Expand Up @@ -83,7 +84,7 @@ class Diagram(db.Model, ModelMixin):
def get_diagram_data(self) -> DiagramData:
import PyQt5.sip # Required for unpickling QtCore objects

data = pickle.loads(self.data) if self.data else {}
data = safe_loads_diagram(self.data) if self.data else {}
pdp_dict = data.get("pdp", {})
known = {f.name for f in dc_fields(DiagramData)} - {"pdp"}
kwargs = {k: data[k] for k in known if k in data}
Expand All @@ -94,7 +95,7 @@ def set_diagram_data(self, diagram_data: DiagramData):
import PyQt5.sip # Required for pickling QtCore objects
from btcopilot.schema import asdict

data = pickle.loads(self.data) if self.data else {}
data = safe_loads_diagram(self.data) if self.data else {}

# Convert PDP dataclass to dict before pickling (JSON-compatible)
data["pdp"] = asdict(diagram_data.pdp)
Expand Down Expand Up @@ -154,7 +155,7 @@ def update_with_version_check(
import PyQt5.sip
from btcopilot.schema import asdict

data = pickle.loads(self.data) if self.data else {}
data = safe_loads_diagram(self.data) if self.data else {}
data["pdp"] = asdict(diagram_data.pdp)
data["lastItemId"] = diagram_data.lastItemId
data["people"] = diagram_data.people
Expand Down
42 changes: 22 additions & 20 deletions btcopilot/pro/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import sys
import re
import pickle

from btcopilot.pro.safe_pickle import safe_loads
import ast
import datetime
import uuid
Expand Down Expand Up @@ -201,7 +203,7 @@ def diagrams(id=None):
# _log.debug(f" Diagram[{diagram['id']}].updated_at: {diagram['updated_at']}")
return pickle.dumps(data)
elif request.method == "POST": # create
args = pickle.loads(request.data)
args = safe_loads(request.data)
diagram = Diagram(
user_id=g.user.id,
name=args["name"],
Expand Down Expand Up @@ -233,7 +235,7 @@ def diagrams(id=None):
elif request.method in ("PATCH", "PUT"): # update
if not diagram.check_write_access(g.user):
return ("Access Denied", 401)
data = pickle.loads(request.data)
data = safe_loads(request.data)
expected_version = data.get("expected_version")

# Support either sending a pickled dict of db model attributes or the pickled scene data.
Expand Down Expand Up @@ -281,7 +283,7 @@ def diagrams(id=None):
@bp.route("/users/status", methods=("POST",))
@encrypted
def users_status():
args = pickle.loads(request.data)
args = safe_loads(request.data)
user = User.query.filter_by(username=args["username"].lower()).first()
if user:
data = {"status": user.status, "id": user.id}
Expand All @@ -304,7 +306,7 @@ def users_status():
@bp.route("/users", methods=("POST",))
@encrypted
def users_create():
args = pickle.loads(request.data)
args = safe_loads(request.data)
regex = r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)"
if not re.search(regex, args["username"].lower()):
return ("Bad Request", 400)
Expand Down Expand Up @@ -355,7 +357,7 @@ def users_email_code(user_id):
@bp.route("/users/<int:user_id>/confirm", methods=("POST",))
@encrypted
def users_confirm(user_id):
args = pickle.loads(request.data)
args = safe_loads(request.data)
user = User.query.get(user_id)
g.user = user
if not user:
Expand All @@ -372,7 +374,7 @@ def users_confirm(user_id):
@bp.route("/users/<int:user_id>", methods=("POST",))
@encrypted
def users_update(user_id):
args = pickle.loads(request.data)
args = safe_loads(request.data)
user = User.query.get(user_id)
g.user = user
if not user:
Expand Down Expand Up @@ -407,7 +409,7 @@ def users_free_diagram(user_id):
return ("Not Found", 404)
if g.user.id != user.id:
return ("Unauthorized", 401)
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -448,7 +450,7 @@ def users_free_diagram(user_id):
@encrypted
def sessions_init():
"""Called when the MainWindow starts up."""
args = pickle.loads(request.data)
args = safe_loads(request.data)
if args.get("token"):
session = Session.query.filter_by(token=args.get("token")).first()
if not session:
Expand Down Expand Up @@ -484,7 +486,7 @@ def sessions_init():
def sessions_login():
import os

args = pickle.loads(request.data)
args = safe_loads(request.data)
password = args.get("password")
username = args.get("username")

Expand Down Expand Up @@ -564,7 +566,7 @@ def sessions_web_auth_token():
@bp.route("/policies", methods=("GET",))
@encrypted
def policies_policies():
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand All @@ -581,7 +583,7 @@ def policies_policies():
@bp.route("/licenses", methods=("POST",))
@encrypted
def licenses_purchase():
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -663,7 +665,7 @@ def licenses_purchase():
@encrypted
@deprecated
def licenses_verify():
args = pickle.loads(request.data)
args = safe_loads(request.data)
ret = {"licenses": []}
for entry in args["licenses"]:
license = License.query.filter_by(key=entry["key"]).first()
Expand All @@ -677,7 +679,7 @@ def licenses_verify():
@bp.route("/licenses/<key>", methods=("GET",))
@encrypted
def licenses_get(key):
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand All @@ -697,7 +699,7 @@ def licenses_get(key):
@bp.route("/licenses/<key>/cancel", methods=("POST",))
@encrypted
def licenses_cancel(key):
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -727,7 +729,7 @@ def licenses_cancel(key):
@bp.route("/licenses/<key>/import", methods=("POST",))
@encrypted
def licenses_import(key):
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -767,7 +769,7 @@ def licenses_import(key):
@bp.route("/machines/<code>", methods=("GET", "POST", "DELETE"))
@encrypted
def machines_machine(code):
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -799,7 +801,7 @@ def machines_machine(code):
@bp.route("/activations", methods=("POST",))
@encrypted
def activations_create():
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -832,7 +834,7 @@ def activations_create():
@bp.route("/activations/<id>", methods=("DELETE",))
@encrypted
def activations_activation(id):
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand All @@ -856,7 +858,7 @@ def activations_activation(id):
@bp.route("/access_rights/<int:id>", methods=("PATCH", "DELETE"))
@encrypted
def access_right(id=None):
args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down Expand Up @@ -885,7 +887,7 @@ def access_right(id=None):
def copilot_chat(conversation_id: int = None):
from btcopilot.pro.copilot import Event

args = pickle.loads(request.data)
args = safe_loads(request.data)
session = Session.query.filter_by(token=args["session"]).first()
if not session:
return ("Unauthorized", 401)
Expand Down
136 changes: 136 additions & 0 deletions btcopilot/pro/safe_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Restricted pickle deserialization to prevent RCE attacks.

The Pro app uses pickle as its wire format. Standard pickle.loads() will
execute arbitrary code embedded in the payload — a textbook RCE vector.

This module provides restricted unpicklers that only allow known-safe types:

- safe_loads(): For untrusted client request data (builtins only)
- safe_loads_diagram(): For diagram blobs that may contain PyQt5.QtCore types
"""

import io
import pickle
import logging

_log = logging.getLogger(__name__)

# Types that are safe to unpickle from any source
_SAFE_BUILTINS = frozenset(
{
"dict",
"list",
"set",
"frozenset",
"tuple",
"bytes",
"bytearray",
"str",
"int",
"float",
"bool",
"complex",
"slice",
"range",
"type",
}
)

# PyQt5.QtCore types used in diagram data (QDate, QDateTime, QPointF, etc.)
# sip._unpickle_type is used internally by PyQt5 for pickle reconstruction
_PYQT5_ALLOWED_MODULES = frozenset(
{
"PyQt5.QtCore",
"sip",
}
)

# datetime types that may appear in pickled data
_DATETIME_ALLOWED = frozenset(
{
"date",
"datetime",
"time",
"timedelta",
"timezone",
}
)

# collections types (e.g. OrderedDict) that may appear
_COLLECTIONS_ALLOWED = frozenset(
{
"OrderedDict",
}
)


class _RestrictedUnpickler(pickle.Unpickler):
"""Unpickler that only allows safe builtin types.

Use for deserializing untrusted client request data where only
basic Python types (dict, list, str, int, etc.) are expected.
"""

def find_class(self, module: str, name: str) -> type:
if module == "builtins" and name in _SAFE_BUILTINS:
return getattr(__import__(module), name)
if module == "datetime" and name in _DATETIME_ALLOWED:
import datetime

return getattr(datetime, name)
if module == "collections" and name in _COLLECTIONS_ALLOWED:
import collections

return getattr(collections, name)
raise pickle.UnpicklingError(
f"Blocked unpickling of {module}.{name} — "
f"only safe builtins are allowed in request data"
)


class _DiagramUnpickler(pickle.Unpickler):
"""Unpickler that allows safe builtins + PyQt5.QtCore types.

Use for deserializing diagram data blobs stored in the database,
which may contain QDate, QDateTime, QPointF, etc.
"""

def find_class(self, module: str, name: str) -> type:
if module == "builtins" and name in _SAFE_BUILTINS:
return getattr(__import__(module), name)
if module == "datetime" and name in _DATETIME_ALLOWED:
import datetime

return getattr(datetime, name)
if module == "collections" and name in _COLLECTIONS_ALLOWED:
import collections

return getattr(collections, name)
if module in _PYQT5_ALLOWED_MODULES:
import importlib

mod = importlib.import_module(module)
return getattr(mod, name)
raise pickle.UnpicklingError(
f"Blocked unpickling of {module}.{name} — "
f"only safe builtins and PyQt5.QtCore types are allowed in diagram data"
)


def safe_loads(data: bytes):
"""Safely deserialize pickle data from untrusted client requests.

Only allows basic Python types (dict, list, str, int, float, bool, etc.).
Raises pickle.UnpicklingError if the payload contains disallowed types.
"""
return _RestrictedUnpickler(io.BytesIO(data)).load()


def safe_loads_diagram(data: bytes):
"""Safely deserialize diagram pickle blobs (database storage).

Allows basic Python types plus PyQt5.QtCore types (QDate, QDateTime, etc.)
that are legitimately used in diagram data.
Raises pickle.UnpicklingError if the payload contains disallowed types.
"""
return _DiagramUnpickler(io.BytesIO(data)).load()
Loading