@@ -3,16 +3,17 @@ package main
33import (
44 "fmt"
55 "io/ioutil"
6+ "net/url"
67 "os"
8+ "strings"
79 "testing"
810
911 "github.com/stretchr/testify/assert"
1012)
1113
1214const (
13- dbFile = "db/data.sqlite.db"
14- dbUrl = "sqlite3://" + dbFile
15- schemaFile = "./db/tables.sqlite3.sql"
15+ dbURL = "sqlite3://:memory:"
16+ schemaFileTemplate = "./db/tables.%s.sql"
1617)
1718
1819var populator * Populator
@@ -34,18 +35,38 @@ func showErrAndExitTest(err error) {
3435}
3536
3637func TestMain (m * testing.M ) {
37- os .Remove (dbFile )
3838 var err error
39- if populator , err = NewPopulator (dbUrl ); err != nil {
39+
40+ connectionURL := dbURL
41+ if envURL := os .Getenv ("DATABASE_URL" ); envURL != "" {
42+ connectionURL = envURL
43+ }
44+
45+ if populator , err = NewPopulator (connectionURL ); err != nil {
4046 showErrAndExitTest (err )
4147 }
42- createTablesStmt , err := ioutil .ReadFile (schemaFile )
48+
49+ uri , err := url .Parse (connectionURL )
4350 if err != nil {
4451 showErrAndExitTest (err )
4552 }
46- if _ , err = populator .DB .Exec (string (createTablesStmt )); err != nil {
53+
54+ schemaFile := fmt .Sprintf (schemaFileTemplate , uri .Scheme )
55+ createTablesStmt , err := ioutil .ReadFile (schemaFile )
56+ if err != nil {
4757 showErrAndExitTest (err )
4858 }
59+
60+ statements := strings .Split (string (createTablesStmt ), ";" )
61+ for _ , statement := range statements {
62+ if strings .TrimSpace (statement ) == "" {
63+ continue
64+ }
65+ if _ , err = populator .DB .Exec (statement ); err != nil {
66+ showErrAndExitTest (err )
67+ }
68+ }
69+
4970 os .Exit (m .Run ())
5071}
5172
0 commit comments