123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- package db
- import (
- "database/sql"
- "errors"
- "os"
- "reflect"
- "regexp"
- "git.clearsky.net.au/cody/gex.git/utils"
- _ "modernc.org/sqlite"
- )
- var Conn *sql.DB
- func Connect(fp string) error {
- _, err := os.Stat(fp)
- if err != nil {
- utils.Err(err)
- os.Exit(1)
- return err
- }
- if Conn, err = sql.Open("sqlite", fp); err != nil {
- utils.Err(err)
- os.Exit(1)
- return err
- }
- Conn.Exec(`
- PRAGMA journal_mode=ON
- PRAGMA journal_mode=WAL;
- PRAGMA synchronous=ON
- PRAGMA auto_vacuum=ON
- PRAGMA foreign_keys=ON
- `)
- return nil
- }
- func Get(model interface{}, query string, params ...any) error {
- if reflect.TypeOf(model).Kind() != reflect.Pointer {
- err := errors.New("model must be a pointer")
- utils.Err(err)
- return err
- }
- if reflect.TypeOf(model).Elem().Kind() != reflect.Struct {
- err := errors.New("model must be a struct")
- utils.Err(err)
- return err
- }
- p := []any{}
- p = append(p, params...)
-
- re := regexp.MustCompile(`:\w+`)
- query = re.ReplaceAllString(query, "?")
- err := Conn.QueryRow(query, p...).Scan(StructPointers(model)...)
- if err != nil {
- err = errors.New("SQL " + err.Error())
- utils.Err(err)
- return err
- }
- return nil
- }
- func All(model interface{}, query string, params ...any) error {
- p := []any{}
- p = append(p, params...)
- if reflect.TypeOf(model).Kind() != reflect.Pointer {
- err := errors.New("db.All(): model must be a pointer")
- utils.Err(err)
- return err
- }
- slicePtr := reflect.ValueOf(model)
- sliceVal := slicePtr.Elem()
- if sliceVal.Kind() != reflect.Slice {
- err := errors.New("db.All(): model must be a slice of model")
- utils.Err(err)
- return err
- }
- modelType := sliceVal.Type().Elem()
-
- re := regexp.MustCompile(`:\w+`)
- query = re.ReplaceAllString(query, "?")
- rows, err := Conn.Query(query, p...)
- if err != nil {
- err := errors.New("db.All(): SQL " + err.Error())
- utils.Err(err)
- return err
- }
- var set []reflect.Value
- for rows.Next() {
- tmpModel := reflect.New(modelType)
- err = rows.Scan(StructPointers(tmpModel.Interface())...)
- set = append(set, tmpModel.Elem())
- }
- if err != nil {
- utils.Err(err)
- return err
- }
- rows.Close()
- sliceVal.Set(reflect.Append(sliceVal, set...))
- return nil
- }
- func Close() {
- Conn.Close()
- }
// StructPointers returns pointers to the fields of the struct that s points
// to, in declaration order, for use as Scan destinations.
//
// Field handling:
//   - a field carrying any struct tag is skipped;
//   - an unexported field is skipped, because reflection cannot hand out its
//     address as an interface value;
//   - a struct-typed field is flattened recursively into its own fields;
//   - a struct-typed field that cannot be flattened (no exported fields and
//     at least one untagged one, e.g. time.Time) is returned as a single
//     pointer to the whole struct. Whether Scan accepts that destination
//     depends on the driver.
//
// nil is returned when s is not a non-nil pointer to a struct.
func StructPointers(s interface{}) []interface{} {
	model := reflect.ValueOf(s)
	if model.Kind() != reflect.Pointer || model.IsNil() {
		return nil
	}
	modelElem := model.Elem()
	if modelElem.Kind() != reflect.Struct {
		return nil
	}
	modelType := modelElem.Type()

	var ptrs []interface{}
	for i := 0; i < modelElem.NumField(); i++ {
		meta := modelType.Field(i)
		if meta.Tag != "" || !meta.IsExported() {
			continue
		}

		field := modelElem.Field(i)
		if field.Kind() == reflect.Struct && !isOpaqueStruct(field.Type()) {
			ptrs = append(ptrs, StructPointers(field.Addr().Interface())...)
			continue
		}
		ptrs = append(ptrs, field.Addr().Interface())
	}
	return ptrs
}

// isOpaqueStruct reports whether struct type t has no exported fields and at
// least one untagged field, i.e. flattening it would yield nothing usable.
func isOpaqueStruct(t reflect.Type) bool {
	untagged := false
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.IsExported() {
			return false
		}
		if f.Tag == "" {
			untagged = true
		}
	}
	return untagged
}
|