|
@@ -0,0 +1,490 @@
|
|
|
+package db
|
|
|
+
|
|
|
+import (
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "math"
|
|
|
+ "net/http"
|
|
|
+ "reflect"
|
|
|
+ "strconv"
|
|
|
+ "strings"
|
|
|
+
|
|
|
+ "git.clearsky.net.au/cody/gex.git/sec"
|
|
|
+)
|
|
|
+
|
|
|
+type Pagination struct {
|
|
|
+ PrevPage string
|
|
|
+ NextPage string
|
|
|
+ Page int
|
|
|
+ Pages int
|
|
|
+ PerPage int
|
|
|
+ Order string
|
|
|
+ Direction string
|
|
|
+ Search string
|
|
|
+ Start int
|
|
|
+}
|
|
|
+
|
|
|
+type Query struct {
|
|
|
+ targetInt interface{}
|
|
|
+ targetPtr reflect.Value
|
|
|
+ targetVal reflect.Value
|
|
|
+ modelType reflect.Type
|
|
|
+
|
|
|
+ tblName string
|
|
|
+ pkName string
|
|
|
+ Sess sec.Sess
|
|
|
+ err error
|
|
|
+ ready bool
|
|
|
+
|
|
|
+ sel string
|
|
|
+ where string
|
|
|
+ order string
|
|
|
+ vals []any
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) New(targetInt interface{}, sess sec.Sess) *Query {
|
|
|
+ q.targetInt = targetInt
|
|
|
+ q.targetPtr = reflect.ValueOf(q.targetInt)
|
|
|
+
|
|
|
+ if q.targetPtr.Kind().String() != "ptr" {
|
|
|
+ q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
|
|
|
+ return q
|
|
|
+ }
|
|
|
+
|
|
|
+ q.targetVal = q.targetPtr.Elem()
|
|
|
+ q.sel = "*"
|
|
|
+
|
|
|
+ // Get the session
|
|
|
+ q.Sess = sess
|
|
|
+
|
|
|
+ if q.targetVal.Kind().String() == "slice" {
|
|
|
+ q.newSlice()
|
|
|
+ return q
|
|
|
+ }
|
|
|
+
|
|
|
+ //Get the type of the model
|
|
|
+ q.modelType = q.targetVal.Type()
|
|
|
+
|
|
|
+ // Get the table name
|
|
|
+ tblMth := q.targetPtr.MethodByName("Table")
|
|
|
+ if !tblMth.IsValid() {
|
|
|
+ q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
|
|
|
+ return q
|
|
|
+ }
|
|
|
+ q.tblName = tblMth.Call(nil)[0].String()
|
|
|
+
|
|
|
+ // Get the primary key name
|
|
|
+ pkMth := q.targetPtr.MethodByName("PK")
|
|
|
+ if !tblMth.IsValid() {
|
|
|
+ q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
|
|
|
+ return q
|
|
|
+ }
|
|
|
+ q.pkName = pkMth.Call(nil)[0].String()
|
|
|
+
|
|
|
+ q.ready = true
|
|
|
+ return q
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) newSlice() {
|
|
|
+ // Get the type of the model
|
|
|
+ q.modelType = q.targetVal.Type().Elem()
|
|
|
+
|
|
|
+ // create a new model so we can run the table and pk methods
|
|
|
+ tmpModel := reflect.New(q.modelType)
|
|
|
+
|
|
|
+ // Get the table name
|
|
|
+ tblMth := tmpModel.MethodByName("Table")
|
|
|
+ if !tblMth.IsValid() {
|
|
|
+ q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
|
|
|
+ }
|
|
|
+ q.tblName = tblMth.Call(nil)[0].String()
|
|
|
+
|
|
|
+ // Get the primary key name
|
|
|
+ pkMth := tmpModel.MethodByName("PK")
|
|
|
+ if !tblMth.IsValid() {
|
|
|
+ q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
|
|
|
+ }
|
|
|
+ q.pkName = pkMth.Call(nil)[0].String()
|
|
|
+ q.ready = true
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Reset() {
|
|
|
+ q.sel = "*"
|
|
|
+ q.where = ""
|
|
|
+
|
|
|
+ var new []any
|
|
|
+ q.vals = new
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Select(s string) *Query {
|
|
|
+ if q.err != nil {
|
|
|
+ return q
|
|
|
+ }
|
|
|
+
|
|
|
+ if !q.ready {
|
|
|
+ q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
|
|
|
+ return q
|
|
|
+ }
|
|
|
+
|
|
|
+ q.sel = s
|
|
|
+ return q
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Where(str string, vals ...any) *Query {
|
|
|
+ if q.err != nil {
|
|
|
+ return q
|
|
|
+ }
|
|
|
+
|
|
|
+ if !q.ready {
|
|
|
+ q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
|
|
|
+ return q
|
|
|
+ }
|
|
|
+
|
|
|
+ q.where = fmt.Sprintf("WHERE (%s) ", str)
|
|
|
+ q.vals = append(q.vals, vals...)
|
|
|
+ return q
|
|
|
+}
|
|
|
+
|
|
|
+// Security check here to ensure str is a field (column)
|
|
|
+func (q *Query) Order(str string) *Query {
|
|
|
+ q.order += "ORDER BY " + str
|
|
|
+ return q
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Create() (int64, error) {
|
|
|
+ if q.err != nil {
|
|
|
+ return 0, q.err
|
|
|
+ }
|
|
|
+ if !q.ready {
|
|
|
+ q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
|
|
|
+ return 0, q.err
|
|
|
+ }
|
|
|
+
|
|
|
+ if q.targetVal.Kind().String() != "struct" {
|
|
|
+ return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
|
|
|
+ }
|
|
|
+
|
|
|
+ fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
|
|
|
+
|
|
|
+ for i := 0; i < len(vals); i++ {
|
|
|
+ if vals[i] == "<nil>" {
|
|
|
+ vals[i] = nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ query := fmt.Sprintf(
|
|
|
+ "INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
|
|
|
+ )
|
|
|
+
|
|
|
+ q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
|
|
|
+ createRuleMethod := q.targetPtr.MethodByName("CreateRule")
|
|
|
+ createRulePassed := createRuleMethod.Call(nil)[0].Bool()
|
|
|
+
|
|
|
+ if !createRulePassed {
|
|
|
+ return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
|
|
|
+ }
|
|
|
+
|
|
|
+ res, err := Conn.Exec(query, vals...)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ stringVals := make([]string, len(vals))
|
|
|
+ for i, v := range vals {
|
|
|
+ stringVals[i] = fmt.Sprint(v)
|
|
|
+ }
|
|
|
+ fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s ", query, strings.Join(stringVals, ","), err.Error())
|
|
|
+
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ lastId, _ := res.LastInsertId()
|
|
|
+
|
|
|
+ q.targetVal.FieldByName(q.pkName).Set(reflect.ValueOf(int(lastId)))
|
|
|
+ return lastId, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Read(id int) error {
|
|
|
+ if q.err != nil {
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+
|
|
|
+ if !q.ready {
|
|
|
+ q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+
|
|
|
+ if q.targetVal.Kind().String() != "struct" {
|
|
|
+ return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
|
|
|
+ }
|
|
|
+
|
|
|
+ err := q.Where(q.pkName+" = ?", id).Get()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Update() error {
|
|
|
+ if q.err != nil {
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+ if !q.ready {
|
|
|
+ q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+
|
|
|
+ if q.targetVal.Kind().String() != "struct" {
|
|
|
+ return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check the existing data from db to check rule
|
|
|
+ id := q.targetVal.FieldByName(q.pkName)
|
|
|
+ tmpModel := reflect.New(q.modelType)
|
|
|
+
|
|
|
+ tmpQuery := Query{}
|
|
|
+ q.err = tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
|
|
|
+
|
|
|
+ if q.err != nil {
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+
|
|
|
+ existRuleMethod := tmpModel.MethodByName("UpdateRule")
|
|
|
+ existRulePassed := existRuleMethod.Call(nil)[0].Bool()
|
|
|
+
|
|
|
+ newRuleMethod := q.targetPtr.MethodByName("UpdateRule")
|
|
|
+ newRulePassed := newRuleMethod.Call(nil)[0].Bool()
|
|
|
+
|
|
|
+ if !existRulePassed || !newRulePassed {
|
|
|
+ return errors.New("ERROR: db.Query.Update() ReadRule not passed)")
|
|
|
+ }
|
|
|
+
|
|
|
+ fields, _, vals := q.getFields(q.targetInt)
|
|
|
+
|
|
|
+ for i := 0; i < len(fields); i++ {
|
|
|
+ fields[i] = fields[i] + "=?"
|
|
|
+ }
|
|
|
+
|
|
|
+ for i := 0; i < len(vals); i++ {
|
|
|
+ if vals[i] == "<nil>" {
|
|
|
+ vals[i] = nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ vals = append(vals, int(id.Int()))
|
|
|
+ query := fmt.Sprintf(
|
|
|
+ "UPDATE %s SET %s WHERE %s=?", q.tblName, strings.Join(fields, ", "), q.pkName,
|
|
|
+ )
|
|
|
+
|
|
|
+ _, err := Conn.Exec(query, vals...)
|
|
|
+ if err != nil {
|
|
|
+ fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
|
|
|
+ fmt.Println("\nQuery Values: ", vals)
|
|
|
+ q.Reset()
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Delete() error {
|
|
|
+ if q.err != nil {
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+ if !q.ready {
|
|
|
+ q.err = errors.New("db.Query.New() needs to run first")
|
|
|
+ return q.err
|
|
|
+ }
|
|
|
+ if q.targetVal.Kind().String() != "struct" {
|
|
|
+ return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check the existing data from db to check rule
|
|
|
+ id := q.targetVal.FieldByName(q.pkName)
|
|
|
+ tmpModel := reflect.New(q.targetVal.Type())
|
|
|
+
|
|
|
+ tmpQuery := Query{}
|
|
|
+ tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
|
|
|
+
|
|
|
+ existRuleMethod := tmpModel.MethodByName("DeleteRule")
|
|
|
+ existRulePassed := existRuleMethod.Call(nil)[0].Bool()
|
|
|
+
|
|
|
+ if !existRulePassed {
|
|
|
+ err := "Update rule failed"
|
|
|
+ fmt.Println(err)
|
|
|
+ return errors.New(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
|
|
|
+
|
|
|
+ _, err := Conn.Exec(query, int(id.Int()))
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return errors.New("db.Query.Get(), " + err.Error())
|
|
|
+ }
|
|
|
+ q.Reset()
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Get() error {
|
|
|
+ q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
|
|
|
+
|
|
|
+ // Create a temp model to store the data
|
|
|
+ tmpModel := reflect.New(q.modelType)
|
|
|
+ tmpModelElem := tmpModel.Elem()
|
|
|
+ tmpModelElem.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
|
|
|
+ //
|
|
|
+
|
|
|
+ query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
|
|
|
+ err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
|
|
|
+ if err != nil {
|
|
|
+ return errors.New("db.Query.Get(), " + err.Error())
|
|
|
+ }
|
|
|
+
|
|
|
+ readRuleMethod := tmpModel.MethodByName("ReadRule")
|
|
|
+ readRulePassed := readRuleMethod.Call(nil)[0].Bool()
|
|
|
+
|
|
|
+ if !readRulePassed {
|
|
|
+ return errors.New("ERROR: db.Query.Get(), read rules not passed")
|
|
|
+ }
|
|
|
+ for i := 0; i < tmpModelElem.NumField(); i++ {
|
|
|
+ q.targetVal.Field(i).Set(tmpModelElem.Field(i))
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) All() error {
|
|
|
+ query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
|
|
|
+ rows, err := Conn.Query(query, q.vals...)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ var set []reflect.Value
|
|
|
+ for rows.Next() {
|
|
|
+ tmpModel := reflect.New(q.modelType)
|
|
|
+ tmpModel.Elem().FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
|
|
|
+ rows.Scan(StructPointers(tmpModel.Interface())...)
|
|
|
+
|
|
|
+ readRuleMethod := tmpModel.MethodByName("ReadRule")
|
|
|
+ readRulePassed := readRuleMethod.Call(nil)[0].Bool()
|
|
|
+
|
|
|
+ if !readRulePassed {
|
|
|
+ return errors.New("ERROR: db.Query.All(), read rules not passed")
|
|
|
+ }
|
|
|
+ set = append(set, tmpModel.Elem())
|
|
|
+ }
|
|
|
+ rows.Close()
|
|
|
+
|
|
|
+ newSlice := reflect.Append(q.targetVal, set...)
|
|
|
+ q.targetVal.Set(newSlice)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) getFields(s interface{}) ([]string, []string, []any) {
|
|
|
+ targetPtr := reflect.ValueOf(s)
|
|
|
+ targetVal := targetPtr.Elem()
|
|
|
+
|
|
|
+ numFields := targetVal.NumField()
|
|
|
+
|
|
|
+ var fields []string
|
|
|
+ var plcHdr []string
|
|
|
+ var vals []any
|
|
|
+
|
|
|
+ for i := 0; i < numFields; i++ {
|
|
|
+ fieldName := targetVal.Type().Field(i).Name
|
|
|
+ if fieldName == q.pkName {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ fieldVal := targetVal.Field(i)
|
|
|
+
|
|
|
+ tag := targetVal.Type().Field(i).Tag
|
|
|
+
|
|
|
+ if tag != "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if fieldVal.Kind() == reflect.Struct {
|
|
|
+ fieldsVal, plcHdrVal, valsVal := q.getFields(fieldVal.Addr().Interface())
|
|
|
+ fields = append(fields, fieldsVal...)
|
|
|
+ plcHdr = append(plcHdr, plcHdrVal...)
|
|
|
+ vals = append(vals, valsVal...)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ fields = append(fields, fmt.Sprintf("%v", fieldName))
|
|
|
+ plcHdr = append(plcHdr, "?")
|
|
|
+
|
|
|
+ vals = append(vals, fmt.Sprintf("%v", fieldVal))
|
|
|
+ }
|
|
|
+
|
|
|
+ return fields, plcHdr, vals
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
|
|
|
+ if r.FormValue("prevPage") != "" {
|
|
|
+ defaults.PrevPage = r.FormValue("prevPage")
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.FormValue("nextPage") != "" {
|
|
|
+ defaults.NextPage = r.FormValue("nextPage")
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.FormValue("page") != "" {
|
|
|
+ defaults.Page, _ = strconv.Atoi(r.FormValue("page"))
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.FormValue("perPage") != "" {
|
|
|
+ defaults.PerPage, _ = strconv.Atoi(r.FormValue("perPage"))
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.FormValue("order") != "" {
|
|
|
+ defaults.Order = r.FormValue("order")
|
|
|
+ }
|
|
|
+
|
|
|
+ if r.FormValue("direction") != "" {
|
|
|
+ defaults.Direction = r.FormValue("direction")
|
|
|
+ }
|
|
|
+
|
|
|
+ defaults.Search = r.FormValue("search")
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func (q *Query) Sort(defaults *Pagination) *Query {
|
|
|
+
|
|
|
+ var count int
|
|
|
+
|
|
|
+ query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
|
|
|
+ _ = Conn.QueryRow(query, q.vals...).Scan(&count)
|
|
|
+
|
|
|
+ pages := math.Ceil(float64(count) / float64(defaults.PerPage))
|
|
|
+ defaults.Pages = int(pages)
|
|
|
+
|
|
|
+ if defaults.PrevPage != "" {
|
|
|
+ defaults.Page--
|
|
|
+ }
|
|
|
+
|
|
|
+ if defaults.NextPage != "" {
|
|
|
+ defaults.Page++
|
|
|
+ }
|
|
|
+
|
|
|
+ if defaults.Page <= 0 {
|
|
|
+ defaults.Page = 1
|
|
|
+ }
|
|
|
+
|
|
|
+ if defaults.Page > defaults.Pages {
|
|
|
+ defaults.Page = defaults.Pages
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ defaults.Start = (defaults.Page - 1) * defaults.PerPage
|
|
|
+
|
|
|
+ if defaults.Start < 0 {
|
|
|
+ defaults.Start = 0
|
|
|
+ }
|
|
|
+
|
|
|
+ q.Order(defaults.Order + " " + defaults.Direction)
|
|
|
+ // this.limit(sort.start, sort.perPage);
|
|
|
+
|
|
|
+ return q
|
|
|
+}
|