db.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package db
import (
	"database/sql"
	"errors"
	"fmt"
	"os"
	"reflect"
	"time"

	"git.clearsky.net.au/cody/gex.git/utils"
	_ "github.com/mattn/go-sqlite3"
)
  10. var Conn *sql.DB
  11. func Connect(fp string) error {
  12. _, err := os.Stat(fp)
  13. if err != nil {
  14. utils.Err(err)
  15. os.Exit(1)
  16. return err
  17. }
  18. if Conn, err = sql.Open("sqlite3", fp); err != nil {
  19. utils.Err(err)
  20. os.Exit(1)
  21. return err
  22. }
  23. Conn.Exec(`
  24. PRAGMA journal_mode=ON
  25. PRAGMA journal_mode=WAL;
  26. PRAGMA synchronous=ON
  27. PRAGMA auto_vacuum=ON
  28. PRAGMA foreign_keys=ON
  29. `)
  30. return nil
  31. }
  32. func Get(model interface{}, query string, params ...any) error {
  33. if reflect.TypeOf(model).Kind() != reflect.Pointer {
  34. err := errors.New("model must be a pointer")
  35. utils.Err(err)
  36. return err
  37. }
  38. if reflect.TypeOf(model).Elem().Kind() != reflect.Struct {
  39. err := errors.New("model must be a struct")
  40. utils.Err(err)
  41. return err
  42. }
  43. p := []any{}
  44. p = append(p, params...)
  45. err := Conn.QueryRow(query, p...).Scan(StructPointers(model)...)
  46. if err != nil {
  47. err = errors.New("SQL " + err.Error())
  48. utils.Err(err)
  49. return err
  50. }
  51. return nil
  52. }
  53. func All(model interface{}, query string, params ...any) error {
  54. p := []any{}
  55. p = append(p, params...)
  56. if reflect.TypeOf(model).Kind() != reflect.Pointer {
  57. err := errors.New("db.All(): model must be a pointer")
  58. utils.Err(err)
  59. return err
  60. }
  61. slicePtr := reflect.ValueOf(model)
  62. sliceVal := slicePtr.Elem()
  63. if sliceVal.Kind() != reflect.Slice {
  64. err := errors.New("db.All(): model must be a slice of model")
  65. utils.Err(err)
  66. return err
  67. }
  68. modelType := sliceVal.Type().Elem()
  69. rows, err := Conn.Query(query, p...)
  70. if err != nil {
  71. err := errors.New("db.All(): SQL " + err.Error())
  72. utils.Err(err)
  73. return err
  74. }
  75. var set []reflect.Value
  76. for rows.Next() {
  77. tmpModel := reflect.New(modelType)
  78. err = rows.Scan(StructPointers(tmpModel.Interface())...)
  79. set = append(set, tmpModel.Elem())
  80. }
  81. if err != nil {
  82. utils.Err(err)
  83. return err
  84. }
  85. rows.Close()
  86. sliceVal.Set(reflect.Append(sliceVal, set...))
  87. return nil
  88. }
  89. func Close() {
  90. Conn.Close()
  91. }
  92. func StructPointers(s interface{}) []interface{} {
  93. model := reflect.ValueOf(s)
  94. modelElem := model.Elem()
  95. numFields := modelElem.NumField()
  96. var Ptrs []interface{}
  97. for i := 0; i < numFields; i++ {
  98. field := modelElem.Field(i)
  99. tag := modelElem.Type().Field(i).Tag
  100. if tag != "" {
  101. continue
  102. }
  103. // Merge
  104. if field.Kind() == reflect.Struct {
  105. Ptrs = append(Ptrs, StructPointers(field.Addr().Interface())...)
  106. continue
  107. }
  108. Ptrs = append(Ptrs, field.Addr().Interface())
  109. }
  110. return Ptrs
  111. }