Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 22 additions & 37 deletions src/expr/src/scalar/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use mz_repr::adt::range::Range;
use mz_repr::adt::regex::Regex;
use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampLike};
use mz_repr::{
ArrayRustType, Datum, DatumList, DatumMap, DatumType, ExcludeNull, Row, RowArena,
SqlScalarType, strconv,
ArrayRustType, ByteStr, ByteString, Datum, DatumList, DatumMap, DatumType, ExcludeNull, Row,
RowArena, SqlScalarType, strconv,
};
use mz_sql_parser::ast::display::FormatMode;
use mz_sql_pretty::{PrettyConfig, pretty_str};
Expand Down Expand Up @@ -329,7 +329,7 @@ fn round_numeric_binary(a: OrderedDecimal<Numeric>, mut b: i32) -> Result<Numeri
}

#[sqlfunc(sqlname = "convert_from", propagates_nulls = true)]
fn convert_from<'a>(a: &'a [u8], b: &str) -> Result<&'a str, EvalError> {
fn convert_from<'a>(a: &'a ByteStr, b: &str) -> Result<&'a str, EvalError> {
// Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
// which the encoding library uses[3].
// [1]: https://www.postgresql.org/docs/9.5/multibyte.html
Expand All @@ -342,7 +342,7 @@ fn convert_from<'a>(a: &'a [u8], b: &str) -> Result<&'a str, EvalError> {
return Err(EvalError::InvalidEncodingName(encoding_name));
}

match str::from_utf8(a) {
match str::from_utf8(a.into()) {
Ok(from) => Ok(from),
Err(e) => Err(EvalError::InvalidByteSequence {
byte_sequence: e.to_string().into(),
Expand All @@ -352,24 +352,24 @@ fn convert_from<'a>(a: &'a [u8], b: &str) -> Result<&'a str, EvalError> {
}

#[sqlfunc]
fn encode(bytes: &[u8], format: &str) -> Result<String, EvalError> {
fn encode(bytes: &ByteStr, format: &str) -> Result<String, EvalError> {
let format = encoding::lookup_format(format)?;
Ok(format.encode(bytes))
Ok(format.encode(&*bytes))
}

#[sqlfunc]
fn decode(string: &str, format: &str) -> Result<Vec<u8>, EvalError> {
fn decode(string: &str, format: &str) -> Result<ByteString, EvalError> {
let format = encoding::lookup_format(format)?;
let out = format.decode(string)?;
if out.len() > MAX_STRING_FUNC_RESULT_BYTES {
Err(EvalError::LengthTooLarge)
} else {
Ok(out)
Ok(out.into())
}
}

#[sqlfunc(sqlname = "length", propagates_nulls = true)]
fn encoded_bytes_char_length(a: &[u8], b: &str) -> Result<i32, EvalError> {
fn encoded_bytes_char_length(a: &ByteStr, b: &str) -> Result<i32, EvalError> {
// Convert PostgreSQL-style encoding names[1] to WHATWG-style encoding names[2],
// which the encoding library uses[3].
// [1]: https://www.postgresql.org/docs/9.5/multibyte.html
Expand All @@ -382,7 +382,7 @@ fn encoded_bytes_char_length(a: &[u8], b: &str) -> Result<i32, EvalError> {
None => return Err(EvalError::InvalidEncodingName(encoding_name)),
};

let decoded_string = match enc.decode(a, DecoderTrap::Strict) {
let decoded_string = match enc.decode(&*a, DecoderTrap::Strict) {
Ok(s) => s,
Err(e) => {
return Err(EvalError::InvalidByteSequence {
Expand Down Expand Up @@ -1378,7 +1378,7 @@ fn power_numeric(mut a: Numeric, b: Numeric) -> Result<Numeric, EvalError> {
}

#[sqlfunc(propagates_nulls = true)]
fn get_bit(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
fn get_bit(bytes: &ByteStr, index: i32) -> Result<i32, EvalError> {
let err = EvalError::IndexOutOfRange {
provided: index,
valid_end: i32::try_from(bytes.len().saturating_mul(8)).unwrap() - 1,
Expand All @@ -1398,7 +1398,7 @@ fn get_bit(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
}

#[sqlfunc(propagates_nulls = true)]
fn get_byte(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
fn get_byte(bytes: &ByteStr, index: i32) -> Result<i32, EvalError> {
let err = EvalError::IndexOutOfRange {
provided: index,
valid_end: i32::try_from(bytes.len()).unwrap() - 1,
Expand All @@ -1410,8 +1410,8 @@ fn get_byte(bytes: &[u8], index: i32) -> Result<i32, EvalError> {
}

#[sqlfunc(sqlname = "constant_time_compare_bytes", propagates_nulls = true)]
pub fn constant_time_eq_bytes(a: &[u8], b: &[u8]) -> bool {
bool::from(a.ct_eq(b))
pub fn constant_time_eq_bytes(a: &ByteStr, b: &ByteStr) -> bool {
bool::from(a.ct_eq(&*b))
}

#[sqlfunc(sqlname = "constant_time_compare_strings", propagates_nulls = true)]
Expand Down Expand Up @@ -3002,33 +3002,18 @@ fn list_remove<'a>(a: DatumList<'a>, b: Datum<'a>, temp_storage: &'a RowArena) -
})
}

#[sqlfunc(
output_type = "Vec<u8>",
sqlname = "digest",
propagates_nulls = true,
introduces_nulls = false
)]
fn digest_string<'a>(a: &str, b: &str, temp_storage: &'a RowArena) -> Result<Datum<'a>, EvalError> {
#[sqlfunc(sqlname = "digest")]
fn digest_string(a: &str, b: &str) -> Result<ByteString, EvalError> {
let to_digest = a.as_bytes();
digest_inner(to_digest, b, temp_storage)
digest_inner(to_digest, b)
}

#[sqlfunc(
output_type = "Vec<u8>",
sqlname = "digest",
propagates_nulls = true,
introduces_nulls = false
)]
fn digest_bytes<'a>(a: &[u8], b: &str, temp_storage: &'a RowArena) -> Result<Datum<'a>, EvalError> {
let to_digest = a;
digest_inner(to_digest, b, temp_storage)
#[sqlfunc(sqlname = "digest")]
fn digest_bytes(to_digest: &ByteStr, b: &str) -> Result<ByteString, EvalError> {
digest_inner(&*to_digest, b)
}

fn digest_inner<'a>(
bytes: &[u8],
digest_fn: &str,
temp_storage: &'a RowArena,
) -> Result<Datum<'a>, EvalError> {
fn digest_inner(bytes: &[u8], digest_fn: &str) -> Result<ByteString, EvalError> {
let bytes = match digest_fn {
"md5" => Md5::digest(bytes).to_vec(),
"sha1" => Sha1::digest(bytes).to_vec(),
Expand All @@ -3038,7 +3023,7 @@ fn digest_inner<'a>(
"sha512" => Sha512::digest(bytes).to_vec(),
other => return Err(EvalError::InvalidHashAlgorithm(other.into())),
};
Ok(Datum::Bytes(temp_storage.push_bytes(bytes)))
Ok(bytes.into())
}

#[sqlfunc(
Expand Down
34 changes: 17 additions & 17 deletions src/expr/src/scalar/func/impls/byte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

use mz_expr_derive::sqlfunc;
use mz_ore::cast::CastFrom;
use mz_repr::strconv;
use mz_repr::{ByteStr, strconv};

use crate::EvalError;

Expand All @@ -18,56 +18,56 @@ use crate::EvalError;
preserves_uniqueness = true,
inverse = to_unary!(super::CastStringToBytes)
)]
fn cast_bytes_to_string(a: &[u8]) -> String {
fn cast_bytes_to_string(a: &ByteStr) -> String {
let mut buf = String::new();
strconv::format_bytes(&mut buf, a);
strconv::format_bytes(&mut buf, &*a);
buf
}

#[sqlfunc(sqlname = "crc32_bytes")]
fn crc32_bytes<'a>(a: &'a [u8]) -> u32 {
fn crc32_bytes(a: &ByteStr) -> u32 {
crc32fast::hash(a)
}

#[sqlfunc(sqlname = "crc32_string")]
fn crc32_string<'a>(a: &'a str) -> u32 {
crc32_bytes(a.as_bytes())
fn crc32_string(a: &str) -> u32 {
crc32_bytes(a.as_bytes().into())
}

#[sqlfunc(sqlname = "kafka_murmur2_bytes")]
fn kafka_murmur2_bytes<'a>(a: &'a [u8]) -> i32 {
i32::from_ne_bytes((murmur2::murmur2(a, murmur2::KAFKA_SEED) & 0x7fffffff).to_ne_bytes())
fn kafka_murmur2_bytes(a: &ByteStr) -> i32 {
i32::from_ne_bytes((murmur2::murmur2(a.into(), murmur2::KAFKA_SEED) & 0x7fffffff).to_ne_bytes())
}

#[sqlfunc(sqlname = "kafka_murmur2_string")]
fn kafka_murmur2_string<'a>(a: &'a str) -> i32 {
kafka_murmur2_bytes(a.as_bytes())
fn kafka_murmur2_string(a: &str) -> i32 {
kafka_murmur2_bytes(a.as_bytes().into())
}

#[sqlfunc(sqlname = "seahash_bytes")]
fn seahash_bytes<'a>(a: &'a [u8]) -> u64 {
seahash::hash(a)
fn seahash_bytes(a: &ByteStr) -> u64 {
seahash::hash(a.into())
}

#[sqlfunc(sqlname = "seahash_string")]
fn seahash_string<'a>(a: &'a str) -> u64 {
seahash_bytes(a.as_bytes())
fn seahash_string(a: &str) -> u64 {
seahash_bytes(a.as_bytes().into())
}

#[sqlfunc(sqlname = "bit_count")]
fn bit_count_bytes<'a>(a: &'a [u8]) -> Result<i64, EvalError> {
fn bit_count_bytes(a: &ByteStr) -> Result<i64, EvalError> {
let count: u64 = a.iter().map(|b| u64::cast_from(b.count_ones())).sum();
i64::try_from(count).or_else(|_| Err(EvalError::Int64OutOfRange(count.to_string().into())))
}

#[sqlfunc(sqlname = "bit_length")]
fn bit_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {
fn bit_length_bytes(a: &ByteStr) -> Result<i32, EvalError> {
let val = a.len() * 8;
i32::try_from(val).or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))
}

#[sqlfunc(sqlname = "octet_length")]
fn byte_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {
fn byte_length_bytes(a: &ByteStr) -> Result<i32, EvalError> {
let val = a.len();
i32::try_from(val).or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: src/expr/src/scalar/func/impls/byte.rs
expression: "#[sqlfunc(sqlname = \"bit_count\")]\nfn bit_count_bytes<'a>(a: &'a [u8]) -> Result<i64, EvalError> {\n let count: u64 = a.iter().map(|b| u64::cast_from(b.count_ones())).sum();\n i64::try_from(count)\n .or_else(|_| Err(EvalError::Int64OutOfRange(count.to_string().into())))\n}\n"
expression: "#[sqlfunc(sqlname = \"bit_count\")]\nfn bit_count_bytes(a: &ByteStr) -> Result<i64, EvalError> {\n let count: u64 = a.iter().map(|b| u64::cast_from(b.count_ones())).sum();\n i64::try_from(count)\n .or_else(|_| Err(EvalError::Int64OutOfRange(count.to_string().into())))\n}\n"
---
#[derive(
proptest_derive::Arbitrary,
Expand All @@ -17,7 +17,7 @@ expression: "#[sqlfunc(sqlname = \"bit_count\")]\nfn bit_count_bytes<'a>(a: &'a
)]
pub struct BitCountBytes;
impl<'a> crate::func::EagerUnaryFunc<'a> for BitCountBytes {
type Input = &'a [u8];
type Input = &'a ByteStr;
type Output = Result<i64, EvalError>;
fn call(&self, a: Self::Input) -> Self::Output {
bit_count_bytes(a)
Expand All @@ -35,7 +35,7 @@ impl std::fmt::Display for BitCountBytes {
f.write_str("bit_count")
}
}
fn bit_count_bytes<'a>(a: &'a [u8]) -> Result<i64, EvalError> {
fn bit_count_bytes(a: &ByteStr) -> Result<i64, EvalError> {
let count: u64 = a.iter().map(|b| u64::cast_from(b.count_ones())).sum();
i64::try_from(count)
.or_else(|_| Err(EvalError::Int64OutOfRange(count.to_string().into())))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: src/expr/src/scalar/func/impls/byte.rs
expression: "#[sqlfunc(sqlname = \"bit_length\")]\nfn bit_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {\n let val = a.len() * 8;\n i32::try_from(val)\n .or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))\n}\n"
expression: "#[sqlfunc(sqlname = \"bit_length\")]\nfn bit_length_bytes(a: &ByteStr) -> Result<i32, EvalError> {\n let val = a.len() * 8;\n i32::try_from(val)\n .or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))\n}\n"
---
#[derive(
proptest_derive::Arbitrary,
Expand All @@ -17,7 +17,7 @@ expression: "#[sqlfunc(sqlname = \"bit_length\")]\nfn bit_length_bytes<'a>(a: &'
)]
pub struct BitLengthBytes;
impl<'a> crate::func::EagerUnaryFunc<'a> for BitLengthBytes {
type Input = &'a [u8];
type Input = &'a ByteStr;
type Output = Result<i32, EvalError>;
fn call(&self, a: Self::Input) -> Self::Output {
bit_length_bytes(a)
Expand All @@ -35,7 +35,7 @@ impl std::fmt::Display for BitLengthBytes {
f.write_str("bit_length")
}
}
fn bit_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {
fn bit_length_bytes(a: &ByteStr) -> Result<i32, EvalError> {
let val = a.len() * 8;
i32::try_from(val)
.or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: src/expr/src/scalar/func/impls/byte.rs
expression: "#[sqlfunc(sqlname = \"octet_length\")]\nfn byte_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {\n let val = a.len();\n i32::try_from(val)\n .or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))\n}\n"
expression: "#[sqlfunc(sqlname = \"octet_length\")]\nfn byte_length_bytes(a: &ByteStr) -> Result<i32, EvalError> {\n let val = a.len();\n i32::try_from(val)\n .or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))\n}\n"
---
#[derive(
proptest_derive::Arbitrary,
Expand All @@ -17,7 +17,7 @@ expression: "#[sqlfunc(sqlname = \"octet_length\")]\nfn byte_length_bytes<'a>(a:
)]
pub struct ByteLengthBytes;
impl<'a> crate::func::EagerUnaryFunc<'a> for ByteLengthBytes {
type Input = &'a [u8];
type Input = &'a ByteStr;
type Output = Result<i32, EvalError>;
fn call(&self, a: Self::Input) -> Self::Output {
byte_length_bytes(a)
Expand All @@ -35,7 +35,7 @@ impl std::fmt::Display for ByteLengthBytes {
f.write_str("octet_length")
}
}
fn byte_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {
fn byte_length_bytes(a: &ByteStr) -> Result<i32, EvalError> {
let val = a.len();
i32::try_from(val)
.or_else(|_| Err(EvalError::Int32OutOfRange(val.to_string().into())))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: src/expr/src/scalar/func/impls/byte.rs
expression: "#[sqlfunc(\n sqlname = \"bytea_to_text\",\n preserves_uniqueness = true,\n inverse = to_unary!(super::CastStringToBytes)\n)]\nfn cast_bytes_to_string(a: &[u8]) -> String {\n let mut buf = String::new();\n strconv::format_bytes(&mut buf, a);\n buf\n}\n"
expression: "#[sqlfunc(\n sqlname = \"bytea_to_text\",\n preserves_uniqueness = true,\n inverse = to_unary!(super::CastStringToBytes)\n)]\nfn cast_bytes_to_string(a: &ByteStr) -> String {\n let mut buf = String::new();\n strconv::format_bytes(&mut buf, &*a);\n buf\n}\n"
---
#[derive(
proptest_derive::Arbitrary,
Expand All @@ -17,7 +17,7 @@ expression: "#[sqlfunc(\n sqlname = \"bytea_to_text\",\n preserves_uniquen
)]
pub struct CastBytesToString;
impl<'a> crate::func::EagerUnaryFunc<'a> for CastBytesToString {
type Input = &'a [u8];
type Input = &'a ByteStr;
type Output = String;
fn call(&self, a: Self::Input) -> Self::Output {
cast_bytes_to_string(a)
Expand All @@ -41,8 +41,8 @@ impl std::fmt::Display for CastBytesToString {
f.write_str("bytea_to_text")
}
}
fn cast_bytes_to_string(a: &[u8]) -> String {
fn cast_bytes_to_string(a: &ByteStr) -> String {
let mut buf = String::new();
strconv::format_bytes(&mut buf, a);
strconv::format_bytes(&mut buf, &*a);
buf
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: src/expr/src/scalar/func/impls/byte.rs
expression: "#[sqlfunc(sqlname = \"crc32_bytes\")]\nfn crc32_bytes<'a>(a: &'a [u8]) -> u32 {\n crc32fast::hash(a)\n}\n"
expression: "#[sqlfunc(sqlname = \"crc32_bytes\")]\nfn crc32_bytes(a: &ByteStr) -> u32 {\n crc32fast::hash(a)\n}\n"
---
#[derive(
proptest_derive::Arbitrary,
Expand All @@ -17,7 +17,7 @@ expression: "#[sqlfunc(sqlname = \"crc32_bytes\")]\nfn crc32_bytes<'a>(a: &'a [u
)]
pub struct Crc32Bytes;
impl<'a> crate::func::EagerUnaryFunc<'a> for Crc32Bytes {
type Input = &'a [u8];
type Input = &'a ByteStr;
type Output = u32;
fn call(&self, a: Self::Input) -> Self::Output {
crc32_bytes(a)
Expand All @@ -35,6 +35,6 @@ impl std::fmt::Display for Crc32Bytes {
f.write_str("crc32_bytes")
}
}
fn crc32_bytes<'a>(a: &'a [u8]) -> u32 {
fn crc32_bytes(a: &ByteStr) -> u32 {
crc32fast::hash(a)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
source: src/expr/src/scalar/func/impls/byte.rs
expression: "#[sqlfunc(sqlname = \"crc32_string\")]\nfn crc32_string<'a>(a: &'a str) -> u32 {\n crc32_bytes(a.as_bytes())\n}\n"
expression: "#[sqlfunc(sqlname = \"crc32_string\")]\nfn crc32_string(a: &str) -> u32 {\n crc32_bytes(a.as_bytes().into())\n}\n"
---
#[derive(
proptest_derive::Arbitrary,
Expand Down Expand Up @@ -35,6 +35,6 @@ impl std::fmt::Display for Crc32String {
f.write_str("crc32_string")
}
}
fn crc32_string<'a>(a: &'a str) -> u32 {
crc32_bytes(a.as_bytes())
fn crc32_string(a: &str) -> u32 {
crc32_bytes(a.as_bytes().into())
}
Loading