Skip to content

Commit b645721

Browse files
authored
IP/CIDR functions. (#246)
1 parent 6c29623 commit b645721

File tree

4 files changed

+219
-18
lines changed

4 files changed

+219
-18
lines changed

ext/ipaddr/ipaddr.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Package ipaddr provides functions to manipulate IPs and CIDRs.
2+
//
3+
// It provides the following functions:
4+
// - ipcontains(prefix, ip)
5+
// - ipoverlaps(prefix1, prefix2)
6+
// - ipfamily(ip/prefix)
7+
// - iphost(ip/prefix)
8+
// - ipmasklen(prefix)
9+
// - ipnetwork(prefix)
10+
package ipaddr
11+
12+
import (
13+
"errors"
14+
"net/netip"
15+
16+
"github.com/ncruces/go-sqlite3"
17+
)
18+
19+
// Register IP/CIDR functions for a database connection.
20+
func Register(db *sqlite3.Conn) error {
21+
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
22+
return errors.Join(
23+
db.CreateFunction("ipcontains", 2, flags, contains),
24+
db.CreateFunction("ipoverlaps", 2, flags, overlaps),
25+
db.CreateFunction("ipfamily", 1, flags, family),
26+
db.CreateFunction("iphost", 1, flags, host),
27+
db.CreateFunction("ipmasklen", 1, flags, masklen),
28+
db.CreateFunction("ipnetwork", 1, flags, network))
29+
}
30+
31+
func contains(ctx sqlite3.Context, arg ...sqlite3.Value) {
32+
prefix, err := netip.ParsePrefix(arg[0].Text())
33+
if err != nil {
34+
ctx.ResultError(err)
35+
return // notest
36+
}
37+
addr, err := netip.ParseAddr(arg[1].Text())
38+
if err != nil {
39+
ctx.ResultError(err)
40+
return // notest
41+
}
42+
ctx.ResultBool(prefix.Contains(addr))
43+
}
44+
45+
func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) {
46+
prefix1, err := netip.ParsePrefix(arg[0].Text())
47+
if err != nil {
48+
ctx.ResultError(err)
49+
return // notest
50+
}
51+
prefix2, err := netip.ParsePrefix(arg[0].Text())
52+
if err != nil {
53+
ctx.ResultError(err)
54+
return // notest
55+
}
56+
ctx.ResultBool(prefix1.Overlaps(prefix2))
57+
}
58+
59+
func family(ctx sqlite3.Context, arg ...sqlite3.Value) {
60+
addr, err := addr(arg[0].Text())
61+
if err != nil {
62+
ctx.ResultError(err)
63+
return // notest
64+
}
65+
switch {
66+
case addr.Is4():
67+
ctx.ResultInt(4)
68+
case addr.Is6():
69+
ctx.ResultInt(6)
70+
}
71+
}
72+
73+
func host(ctx sqlite3.Context, arg ...sqlite3.Value) {
74+
addr, err := addr(arg[0].Text())
75+
if err != nil {
76+
ctx.ResultError(err)
77+
return // notest
78+
}
79+
buf, _ := addr.MarshalText()
80+
ctx.ResultRawText(buf)
81+
}
82+
83+
func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) {
84+
prefix, err := netip.ParsePrefix(arg[0].Text())
85+
if err != nil {
86+
ctx.ResultError(err)
87+
return // notest
88+
}
89+
ctx.ResultInt(prefix.Bits())
90+
}
91+
92+
func network(ctx sqlite3.Context, arg ...sqlite3.Value) {
93+
prefix, err := netip.ParsePrefix(arg[0].Text())
94+
if err != nil {
95+
ctx.ResultError(err)
96+
return // notest
97+
}
98+
buf, _ := prefix.Masked().MarshalText()
99+
ctx.ResultRawText(buf)
100+
}
101+
102+
func addr(text string) (netip.Addr, error) {
103+
addr, err := netip.ParseAddr(text)
104+
if err != nil {
105+
if prefix, err := netip.ParsePrefix(text); err == nil {
106+
return prefix.Addr(), nil
107+
}
108+
if addrpt, err := netip.ParseAddrPort(text); err == nil {
109+
return addrpt.Addr(), nil
110+
}
111+
}
112+
return addr, err
113+
}

ext/ipaddr/ipaddr_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package ipaddr_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/ncruces/go-sqlite3/driver"
7+
_ "github.com/ncruces/go-sqlite3/embed"
8+
"github.com/ncruces/go-sqlite3/ext/ipaddr"
9+
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
10+
"github.com/ncruces/go-sqlite3/vfs/memdb"
11+
)
12+
13+
func TestRegister(t *testing.T) {
14+
t.Parallel()
15+
tmp := memdb.TestDB(t)
16+
17+
db, err := driver.Open(tmp, ipaddr.Register)
18+
if err != nil {
19+
t.Fatal(err)
20+
}
21+
defer db.Close()
22+
23+
var got string
24+
25+
err = db.QueryRow(`SELECT ipfamily('::1')`).Scan(&got)
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
if got != "6" {
30+
t.Fatalf("got %s", got)
31+
}
32+
33+
err = db.QueryRow(`SELECT ipfamily('[::1]:80')`).Scan(&got)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
if got != "6" {
38+
t.Fatalf("got %s", got)
39+
}
40+
41+
err = db.QueryRow(`SELECT ipfamily('192.168.1.5/24')`).Scan(&got)
42+
if err != nil {
43+
t.Fatal(err)
44+
}
45+
if got != "4" {
46+
t.Fatalf("got %s", got)
47+
}
48+
49+
err = db.QueryRow(`SELECT iphost('192.168.1.5/24')`).Scan(&got)
50+
if err != nil {
51+
t.Fatal(err)
52+
}
53+
if got != "192.168.1.5" {
54+
t.Fatalf("got %s", got)
55+
}
56+
57+
err = db.QueryRow(`SELECT ipmasklen('192.168.1.5/24')`).Scan(&got)
58+
if err != nil {
59+
t.Fatal(err)
60+
}
61+
if got != "24" {
62+
t.Fatalf("got %s", got)
63+
}
64+
65+
err = db.QueryRow(`SELECT ipnetwork('192.168.1.5/24')`).Scan(&got)
66+
if err != nil {
67+
t.Fatal(err)
68+
}
69+
if got != "192.168.1.0/24" {
70+
t.Fatalf("got %s", got)
71+
}
72+
73+
err = db.QueryRow(`SELECT ipcontains('192.168.1.0/24', '192.168.1.5')`).Scan(&got)
74+
if err != nil {
75+
t.Fatal(err)
76+
}
77+
if got != "1" {
78+
t.Fatalf("got %s", got)
79+
}
80+
81+
err = db.QueryRow(`SELECT ipoverlaps('192.168.1.0/24', '192.168.1.5/32')`).Scan(&got)
82+
if err != nil {
83+
t.Fatal(err)
84+
}
85+
if got != "1" {
86+
t.Fatalf("got %s", got)
87+
}
88+
}

ext/unicode/unicode.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
// Package unicode provides an alternative to the SQLite ICU extension.
22
//
33
// Like the [ICU extension], it provides Unicode aware:
4-
// - upper() and lower() functions,
5-
// - LIKE and REGEXP operators,
6-
// - collation sequences.
4+
// - upper() and lower() functions
5+
// - LIKE and REGEXP operators
6+
// - collation sequences
77
//
88
// Like PostgreSQL, it also provides:
9-
// - initcap(),
10-
// - casefold(),
11-
// - normalize(),
12-
// - unaccent().
9+
// - initcap()
10+
// - casefold()
11+
// - normalize()
12+
// - unaccent()
1313
//
1414
// The implementations are not 100% compatible:
15-
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases];
16-
// - normalize(), unaccent() use [transform] and [unicode.Mn];
17-
// - the LIKE operator follows [strings.EqualFold] rules;
18-
// - the REGEXP operator uses Go [regexp/syntax];
19-
// - collation sequences use [collate].
15+
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]
16+
// - normalize(), unaccent() use [transform] and [unicode.Mn]
17+
// - the LIKE operator follows [strings.EqualFold] rules
18+
// - the REGEXP operator uses Go [regexp/syntax]
19+
// - collation sequences use [collate]
2020
//
2121
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
2222
//

ext/uuid/uuid.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ import (
1818
// Register registers the SQL functions:
1919
//
2020
// - uuid([ version [, domain/namespace, [ id/data ]]]):
21-
// to generate a UUID as a string,
21+
// to generate a UUID as a string
2222
// - uuid_str(u):
23-
// to convert a UUID into a well-formed UUID string,
23+
// to convert a UUID into a well-formed UUID string
2424
// - uuid_blob(u):
25-
// to convert a UUID into a 16-byte blob,
25+
// to convert a UUID into a 16-byte blob
2626
// - uuid_extract_version(u):
27-
// to extract the version of a RFC 4122 UUID,
27+
// to extract the version of a RFC 4122 UUID
2828
// - uuid_extract_timestamp(u):
29-
// to extract the timestamp of a version 1/2/6/7 UUID,
29+
// to extract the timestamp of a version 1/2/6/7 UUID
3030
// - gen_random_uuid(u):
31-
// to generate a version 4 (random) UUID.
31+
// to generate a version 4 (random) UUID
3232
func Register(db *sqlite3.Conn) error {
3333
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
3434
return errors.Join(

0 commit comments

Comments
 (0)