query.dis 11 KB


  1. package db
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "net/http"
  7. "reflect"
  8. "strconv"
  9. "strings"
  10. "git.clearsky.net.au/cody/gex.git/sec"
  11. )
  // Pagination carries the paging, ordering and search state of a list view.
  // SortParams fills it from request form values and Sort normalises it
  // against the number of matching rows.
  type Pagination struct {
  	// PrevPage and NextPage hold the raw "prevPage"/"nextPage" form values;
  	// Sort steps Page back or forward by one when the respective value is
  	// non-empty.
  	PrevPage, NextPage string
  	// Page is the current page, Pages the page count computed by Sort and
  	// PerPage the number of rows per page.
  	Page, Pages, PerPage int
  	// Order and Direction are joined into the ORDER BY clause by Sort;
  	// Search holds the "search" form value.
  	Order, Direction, Search string
  	// Start is the row offset of the current page, computed by Sort.
  	Start int
  }
  23. type Query struct {
  24. targetInt interface{}
  25. targetPtr reflect.Value
  26. targetVal reflect.Value
  27. modelType reflect.Type
  28. tblName string
  29. pkName string
  30. Auth *sec.Auth
  31. err error
  32. ready bool
  33. sel string
  34. where string
  35. order string
  36. vals []any
  37. }
  38. func (q *Query) New(targetInt interface{}, auth *sec.Auth) *Query {
  39. q.targetInt = targetInt
  40. q.targetPtr = reflect.ValueOf(q.targetInt)
  41. if q.targetPtr.Kind().String() != "ptr" {
  42. q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
  43. return q
  44. }
  45. q.targetVal = q.targetPtr.Elem()
  46. q.sel = "*"
  47. // Get the secion
  48. q.Auth = auth
  49. if q.targetVal.Kind().String() == "slice" {
  50. q.newSlice()
  51. return q
  52. }
  53. //Get the type of the model
  54. q.modelType = q.targetVal.Type()
  55. // Get the table name
  56. tblMth := q.targetPtr.MethodByName("Table")
  57. if !tblMth.IsValid() {
  58. q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
  59. return q
  60. }
  61. q.tblName = tblMth.Call(nil)[0].String()
  62. // Get the primary key name
  63. pkMth := q.targetPtr.MethodByName("PK")
  64. if !tblMth.IsValid() {
  65. q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
  66. return q
  67. }
  68. q.pkName = pkMth.Call(nil)[0].String()
  69. q.ready = true
  70. return q
  71. }
  72. func (q *Query) newSlice() {
  73. // Get the type of the model
  74. q.modelType = q.targetVal.Type().Elem()
  75. // create a new model so we can run the table and pk methods
  76. tmpModel := reflect.New(q.modelType)
  77. // Get the table name
  78. tblMth := tmpModel.MethodByName("Table")
  79. if !tblMth.IsValid() {
  80. q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
  81. }
  82. q.tblName = tblMth.Call(nil)[0].String()
  83. // Get the primary key name
  84. pkMth := tmpModel.MethodByName("PK")
  85. if !tblMth.IsValid() {
  86. q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
  87. }
  88. q.pkName = pkMth.Call(nil)[0].String()
  89. q.ready = true
  90. }
  91. func (q *Query) Reset() {
  92. q.sel = "*"
  93. q.where = ""
  94. var new []any
  95. q.vals = new
  96. }
  97. func (q *Query) Select(s string) *Query {
  98. if q.err != nil {
  99. return q
  100. }
  101. if !q.ready {
  102. q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
  103. return q
  104. }
  105. q.sel = s
  106. return q
  107. }
  108. func (q *Query) Where(str string, vals ...any) *Query {
  109. if q.err != nil {
  110. return q
  111. }
  112. if !q.ready {
  113. q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
  114. return q
  115. }
  116. q.where = fmt.Sprintf("WHERE (%s) ", str)
  117. q.vals = append(q.vals, vals...)
  118. return q
  119. }
  120. // Security check here to ensure str is a field (column)
  121. func (q *Query) Order(str string) *Query {
  122. q.order += "ORDER BY " + str
  123. return q
  124. }
  125. func (q *Query) Create() (int64, error) {
  126. if q.err != nil {
  127. return 0, q.err
  128. }
  129. if !q.ready {
  130. q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
  131. return 0, q.err
  132. }
  133. if q.targetVal.Kind().String() != "struct" {
  134. return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
  135. }
  136. fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
  137. for i := 0; i < len(vals); i++ {
  138. if vals[i] == "<nil>" {
  139. vals[i] = nil
  140. }
  141. }
  142. query := fmt.Sprintf(
  143. "INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
  144. )
  145. q.targetVal.FieldByName("sec").Set(reflect.ValueOf(&q.Auth))
  146. createRuleMethod := q.targetPtr.MethodByName("CreateRule")
  147. createRulePassed := createRuleMethod.Call(nil)[0].Bool()
  148. if !createRulePassed {
  149. return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
  150. }
  151. res, err := Conn.Exec(query, vals...)
  152. if err != nil {
  153. stringVals := make([]string, len(vals))
  154. for i, v := range vals {
  155. stringVals[i] = fmt.Sprint(v)
  156. }
  157. fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s ", query, strings.Join(stringVals, ","), err.Error())
  158. return 0, err
  159. }
  160. lastId, _ := res.LastInsertId()
  161. q.targetVal.FieldByName(q.pkName).Set(reflect.ValueOf(int(lastId)))
  162. return lastId, nil
  163. }
  164. func (q *Query) Read(id int) error {
  165. if q.err != nil {
  166. return q.err
  167. }
  168. if !q.ready {
  169. q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
  170. return q.err
  171. }
  172. if q.targetVal.Kind().String() != "struct" {
  173. return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
  174. }
  175. err := q.Where(q.pkName+" = ?", id).Get()
  176. if err != nil {
  177. return err
  178. }
  179. return nil
  180. }
  181. func (q *Query) Update() error {
  182. if q.err != nil {
  183. return q.err
  184. }
  185. if !q.ready {
  186. q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
  187. return q.err
  188. }
  189. if q.targetVal.Kind().String() != "struct" {
  190. return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
  191. }
  192. // Check the existing data from db to check rule
  193. id := q.targetVal.FieldByName(q.pkName)
  194. tmpModel := reflect.New(q.modelType)
  195. tmpQuery := Query{}
  196. q.err = tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
  197. if q.err != nil {
  198. return q.err
  199. }
  200. existRuleMethod := tmpModel.MethodByName("UpdateRule")
  201. existRulePassed := existRuleMethod.Call(nil)[0].Bool()
  202. newRuleMethod := q.targetPtr.MethodByName("UpdateRule")
  203. newRulePassed := newRuleMethod.Call(nil)[0].Bool()
  204. if !existRulePassed || !newRulePassed {
  205. return errors.New("ERROR: db.Query.Update() ReadRule not passed)")
  206. }
  207. fields, _, vals := q.getFields(q.targetInt)
  208. for i := 0; i < len(fields); i++ {
  209. fields[i] = fields[i] + "=?"
  210. }
  211. for i := 0; i < len(vals); i++ {
  212. if vals[i] == "<nil>" {
  213. vals[i] = nil
  214. }
  215. }
  216. vals = append(vals, int(id.Int()))
  217. query := fmt.Sprintf(
  218. "UPDATE %s SET %s WHERE %s=?", q.tblName, strings.Join(fields, ", "), q.pkName,
  219. )
  220. _, err := Conn.Exec(query, vals...)
  221. if err != nil {
  222. fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
  223. fmt.Println("\nQuery Values: ", vals)
  224. q.Reset()
  225. return err
  226. }
  227. return nil
  228. }
  229. func (q *Query) Delete() error {
  230. if q.err != nil {
  231. return q.err
  232. }
  233. if !q.ready {
  234. q.err = errors.New("db.Query.New() needs to run first")
  235. return q.err
  236. }
  237. if q.targetVal.Kind().String() != "struct" {
  238. return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
  239. }
  240. // Check the existing data from db to check rule
  241. id := q.targetVal.FieldByName(q.pkName)
  242. tmpModel := reflect.New(q.targetVal.Type())
  243. tmpQuery := Query{}
  244. tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
  245. existRuleMethod := tmpModel.MethodByName("DeleteRule")
  246. existRulePassed := existRuleMethod.Call(nil)[0].Bool()
  247. if !existRulePassed {
  248. err := "Update rule failed"
  249. fmt.Println(err)
  250. return errors.New(err)
  251. }
  252. query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
  253. _, err := Conn.Exec(query, int(id.Int()))
  254. if err != nil {
  255. return errors.New("db.Query.Get(), " + err.Error())
  256. }
  257. q.Reset()
  258. return nil
  259. }
  260. func (q *Query) Get() error {
  261. q.targetVal.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
  262. // Create a temp model to store the data
  263. tmpModel := reflect.New(q.modelType)
  264. tmpModelElem := tmpModel.Elem()
  265. tmpModelElem.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
  266. //
  267. query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
  268. err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
  269. if err != nil {
  270. return errors.New("db.Query.Get(), " + err.Error())
  271. }
  272. readRuleMethod := tmpModel.MethodByName("ReadRule")
  273. readRulePassed := readRuleMethod.Call(nil)[0].Bool()
  274. if !readRulePassed {
  275. return errors.New("ERROR: db.Query.Get(), read rules not passed")
  276. }
  277. for i := 0; i < tmpModelElem.NumField(); i++ {
  278. q.targetVal.Field(i).Set(tmpModelElem.Field(i))
  279. }
  280. return nil
  281. }
  282. func (q *Query) All() error {
  283. query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
  284. rows, err := Conn.Query(query, q.vals...)
  285. if err != nil {
  286. return err
  287. }
  288. var set []reflect.Value
  289. for rows.Next() {
  290. tmpModel := reflect.New(q.modelType)
  291. tmpModel.Elem().FieldByName("sec").Set(reflect.ValueOf(q.Auth))
  292. rows.Scan(StructPointers(tmpModel.Interface())...)
  293. readRuleMethod := tmpModel.MethodByName("ReadRule")
  294. readRulePassed := readRuleMethod.Call(nil)[0].Bool()
  295. if !readRulePassed {
  296. return errors.New("ERROR: db.Query.All(), read rules not passed")
  297. }
  298. set = append(set, tmpModel.Elem())
  299. }
  300. rows.Close()
  301. newSlice := reflect.Append(q.targetVal, set...)
  302. q.targetVal.Set(newSlice)
  303. return nil
  304. }
  305. func (q *Query) getFields(s interface{}) ([]string, []string, []any) {
  306. targetPtr := reflect.ValueOf(s)
  307. targetVal := targetPtr.Elem()
  308. numFields := targetVal.NumField()
  309. var fields []string
  310. var plcHdr []string
  311. var vals []any
  312. for i := 0; i < numFields; i++ {
  313. fieldName := targetVal.Type().Field(i).Name
  314. if fieldName == q.pkName {
  315. continue
  316. }
  317. fieldVal := targetVal.Field(i)
  318. tag := targetVal.Type().Field(i).Tag
  319. if tag != "" {
  320. continue
  321. }
  322. if fieldVal.Kind() == reflect.Struct {
  323. fieldsVal, plcHdrVal, valsVal := q.getFields(fieldVal.Addr().Interface())
  324. fields = append(fields, fieldsVal...)
  325. plcHdr = append(plcHdr, plcHdrVal...)
  326. vals = append(vals, valsVal...)
  327. continue
  328. }
  329. fields = append(fields, fmt.Sprintf("%v", fieldName))
  330. plcHdr = append(plcHdr, "?")
  331. vals = append(vals, fmt.Sprintf("%v", fieldVal))
  332. }
  333. return fields, plcHdr, vals
  334. }
  335. func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
  336. if r.FormValue("prevPage") != "" {
  337. defaults.PrevPage = r.FormValue("prevPage")
  338. }
  339. if r.FormValue("nextPage") != "" {
  340. defaults.NextPage = r.FormValue("nextPage")
  341. }
  342. if r.FormValue("page") != "" {
  343. defaults.Page, _ = strconv.Atoi(r.FormValue("page"))
  344. }
  345. if r.FormValue("perPage") != "" {
  346. defaults.PerPage, _ = strconv.Atoi(r.FormValue("perPage"))
  347. }
  348. if r.FormValue("order") != "" {
  349. defaults.Order = r.FormValue("order")
  350. }
  351. if r.FormValue("direction") != "" {
  352. defaults.Direction = r.FormValue("direction")
  353. }
  354. defaults.Search = r.FormValue("search")
  355. }
  356. func (q *Query) Sort(defaults *Pagination) *Query {
  357. var count int
  358. query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
  359. _ = Conn.QueryRow(query, q.vals...).Scan(&count)
  360. pages := math.Ceil(float64(count) / float64(defaults.PerPage))
  361. defaults.Pages = int(pages)
  362. if defaults.PrevPage != "" {
  363. defaults.Page--
  364. }
  365. if defaults.NextPage != "" {
  366. defaults.Page++
  367. }
  368. if defaults.Page <= 0 {
  369. defaults.Page = 1
  370. }
  371. if defaults.Page > defaults.Pages {
  372. defaults.Page = defaults.Pages
  373. }
  374. defaults.Start = (defaults.Page - 1) * defaults.PerPage
  375. if defaults.Start < 0 {
  376. defaults.Start = 0
  377. }
  378. q.Order(defaults.Order + " " + defaults.Direction)
  379. // this.limit(sort.start, sort.perPage);
  380. return q
  381. }