db.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package db
  2. import (
  3. "database/sql"
  4. "errors"
  5. "os"
  6. "reflect"
  7. "regexp"
  8. "git.clearsky.net.au/cody/gex.git/utils"
  9. _ "modernc.org/sqlite"
  10. )
  11. var Conn *sql.DB
  12. func Connect(fp string) error {
  13. _, err := os.Stat(fp)
  14. if err != nil {
  15. utils.Err(err)
  16. os.Exit(1)
  17. return err
  18. }
  19. if Conn, err = sql.Open("sqlite", fp); err != nil {
  20. utils.Err(err)
  21. os.Exit(1)
  22. return err
  23. }
  24. Conn.Exec(`
  25. PRAGMA journal_mode=ON
  26. PRAGMA journal_mode=WAL;
  27. PRAGMA synchronous=ON
  28. PRAGMA auto_vacuum=ON
  29. PRAGMA foreign_keys=ON
  30. `)
  31. return nil
  32. }
  33. func Get(model interface{}, query string, params ...any) error {
  34. if reflect.TypeOf(model).Kind() != reflect.Pointer {
  35. err := errors.New("model must be a pointer")
  36. utils.Err(err)
  37. return err
  38. }
  39. if reflect.TypeOf(model).Elem().Kind() != reflect.Struct {
  40. err := errors.New("model must be a struct")
  41. utils.Err(err)
  42. return err
  43. }
  44. p := []any{}
  45. p = append(p, params...)
  46. //update named params to ?
  47. re := regexp.MustCompile(`:\w+`)
  48. query = re.ReplaceAllString(query, "?")
  49. err := Conn.QueryRow(query, p...).Scan(StructPointers(model)...)
  50. if err != nil {
  51. err = errors.New("SQL " + err.Error())
  52. utils.Err(err)
  53. return err
  54. }
  55. return nil
  56. }
  57. func All(model interface{}, query string, params ...any) error {
  58. p := []any{}
  59. p = append(p, params...)
  60. if reflect.TypeOf(model).Kind() != reflect.Pointer {
  61. err := errors.New("db.All(): model must be a pointer")
  62. utils.Err(err)
  63. return err
  64. }
  65. slicePtr := reflect.ValueOf(model)
  66. sliceVal := slicePtr.Elem()
  67. if sliceVal.Kind() != reflect.Slice {
  68. err := errors.New("db.All(): model must be a slice of model")
  69. utils.Err(err)
  70. return err
  71. }
  72. modelType := sliceVal.Type().Elem()
  73. //update named params to ?
  74. re := regexp.MustCompile(`:\w+`)
  75. query = re.ReplaceAllString(query, "?")
  76. rows, err := Conn.Query(query, p...)
  77. if err != nil {
  78. err := errors.New("db.All(): SQL " + err.Error())
  79. utils.Err(err)
  80. return err
  81. }
  82. var set []reflect.Value
  83. for rows.Next() {
  84. tmpModel := reflect.New(modelType)
  85. err = rows.Scan(StructPointers(tmpModel.Interface())...)
  86. set = append(set, tmpModel.Elem())
  87. }
  88. if err != nil {
  89. utils.Err(err)
  90. return err
  91. }
  92. rows.Close()
  93. sliceVal.Set(reflect.Append(sliceVal, set...))
  94. return nil
  95. }
  96. func Close() {
  97. Conn.Close()
  98. }
  99. func StructPointers(s interface{}) []interface{} {
  100. model := reflect.ValueOf(s)
  101. modelElem := model.Elem()
  102. numFields := modelElem.NumField()
  103. var Ptrs []interface{}
  104. for i := 0; i < numFields; i++ {
  105. field := modelElem.Field(i)
  106. tag := modelElem.Type().Field(i).Tag
  107. if tag != "" {
  108. continue
  109. }
  110. // Merge
  111. if field.Kind() == reflect.Struct {
  112. Ptrs = append(Ptrs, StructPointers(field.Addr().Interface())...)
  113. continue
  114. }
  115. Ptrs = append(Ptrs, field.Addr().Interface())
  116. }
  117. return Ptrs
  118. }