|
@@ -1,495 +0,0 @@
|
|
|
-package db
|
|
|
-
|
|
|
-import (
|
|
|
- "errors"
|
|
|
- "fmt"
|
|
|
- "math"
|
|
|
- "net/http"
|
|
|
- "reflect"
|
|
|
- "strconv"
|
|
|
- "strings"
|
|
|
-
|
|
|
- "git.clearsky.net.au/cody/gex.git/sec"
|
|
|
-)
|
|
|
-
|
|
|
-type Pagination struct {
|
|
|
- PrevPage string
|
|
|
- NextPage string
|
|
|
- Page int
|
|
|
- Pages int
|
|
|
- PerPage int
|
|
|
- Order string
|
|
|
- Direction string
|
|
|
- Search string
|
|
|
- Start int
|
|
|
-}
|
|
|
-
|
|
|
-type Query struct {
|
|
|
- targetInt interface{}
|
|
|
- targetPtr reflect.Value
|
|
|
- targetVal reflect.Value
|
|
|
- modelType reflect.Type
|
|
|
-
|
|
|
- tblName string
|
|
|
- pkName string
|
|
|
- Auth *sec.Auth
|
|
|
- err error
|
|
|
- ready bool
|
|
|
-
|
|
|
- sel string
|
|
|
- where string
|
|
|
- order string
|
|
|
- vals []any
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) New(targetInt interface{}, auth *sec.Auth) *Query {
|
|
|
-
|
|
|
- q.targetInt = targetInt
|
|
|
- q.targetPtr = reflect.ValueOf(q.targetInt)
|
|
|
-
|
|
|
- if q.targetPtr.Kind().String() != "ptr" {
|
|
|
- q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
|
|
|
- return q
|
|
|
- }
|
|
|
-
|
|
|
- q.targetVal = q.targetPtr.Elem()
|
|
|
- q.sel = "*"
|
|
|
-
|
|
|
- // Get the secion
|
|
|
- q.Auth = auth
|
|
|
-
|
|
|
- if q.targetVal.Kind().String() == "slice" {
|
|
|
- q.newSlice()
|
|
|
- return q
|
|
|
- }
|
|
|
-
|
|
|
- //Get the type of the model
|
|
|
- q.modelType = q.targetVal.Type()
|
|
|
-
|
|
|
- // Get the table name
|
|
|
- tblMth := q.targetPtr.MethodByName("Table")
|
|
|
- if !tblMth.IsValid() {
|
|
|
- q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
|
|
|
- return q
|
|
|
- }
|
|
|
- q.tblName = tblMth.Call(nil)[0].String()
|
|
|
-
|
|
|
- // Get the primary key name
|
|
|
- pkMth := q.targetPtr.MethodByName("PK")
|
|
|
- if !tblMth.IsValid() {
|
|
|
- q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
|
|
|
- return q
|
|
|
- }
|
|
|
- q.pkName = pkMth.Call(nil)[0].String()
|
|
|
-
|
|
|
- q.ready = true
|
|
|
- return q
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) newSlice() {
|
|
|
- // Get the type of the model
|
|
|
- q.modelType = q.targetVal.Type().Elem()
|
|
|
-
|
|
|
- // create a new model so we can run the table and pk methods
|
|
|
- tmpModel := reflect.New(q.modelType)
|
|
|
-
|
|
|
- // Get the table name
|
|
|
- tblMth := tmpModel.MethodByName("Table")
|
|
|
- if !tblMth.IsValid() {
|
|
|
- q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
|
|
|
- }
|
|
|
- q.tblName = tblMth.Call(nil)[0].String()
|
|
|
-
|
|
|
- // Get the primary key name
|
|
|
- pkMth := tmpModel.MethodByName("PK")
|
|
|
- if !tblMth.IsValid() {
|
|
|
- q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
|
|
|
- }
|
|
|
- q.pkName = pkMth.Call(nil)[0].String()
|
|
|
- q.ready = true
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Reset() {
|
|
|
- q.sel = "*"
|
|
|
- q.where = ""
|
|
|
-
|
|
|
- var new []any
|
|
|
- q.vals = new
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Select(s string) *Query {
|
|
|
- if q.err != nil {
|
|
|
- return q
|
|
|
- }
|
|
|
-
|
|
|
- if !q.ready {
|
|
|
- q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
|
|
|
- return q
|
|
|
- }
|
|
|
-
|
|
|
- q.sel = s
|
|
|
- return q
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Where(str string, vals ...any) *Query {
|
|
|
- if q.err != nil {
|
|
|
- return q
|
|
|
- }
|
|
|
-
|
|
|
- if !q.ready {
|
|
|
- q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
|
|
|
- return q
|
|
|
- }
|
|
|
-
|
|
|
- q.where = fmt.Sprintf("WHERE (%s) ", str)
|
|
|
- q.vals = append(q.vals, vals...)
|
|
|
- return q
|
|
|
-}
|
|
|
-
|
|
|
-// Security check here to ensure str is a field (column)
|
|
|
-func (q *Query) Order(str string) *Query {
|
|
|
- q.order += "ORDER BY " + str
|
|
|
- return q
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Create() (int64, error) {
|
|
|
- if q.err != nil {
|
|
|
- return 0, q.err
|
|
|
- }
|
|
|
- if !q.ready {
|
|
|
- q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
|
|
|
- return 0, q.err
|
|
|
- }
|
|
|
-
|
|
|
- if q.targetVal.Kind().String() != "struct" {
|
|
|
- return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
|
|
|
- }
|
|
|
-
|
|
|
- fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
|
|
|
-
|
|
|
- for i := 0; i < len(vals); i++ {
|
|
|
- if vals[i] == "<nil>" {
|
|
|
- vals[i] = nil
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- query := fmt.Sprintf(
|
|
|
- "INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
|
|
|
- )
|
|
|
-
|
|
|
- q.targetVal.FieldByName("sec").Set(reflect.ValueOf(&q.Auth))
|
|
|
- createRuleMethod := q.targetPtr.MethodByName("CreateRule")
|
|
|
- createRulePassed := createRuleMethod.Call(nil)[0].Bool()
|
|
|
-
|
|
|
- if !createRulePassed {
|
|
|
- return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
|
|
|
- }
|
|
|
-
|
|
|
- res, err := Conn.Exec(query, vals...)
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- stringVals := make([]string, len(vals))
|
|
|
- for i, v := range vals {
|
|
|
- stringVals[i] = fmt.Sprint(v)
|
|
|
- }
|
|
|
- fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s ", query, strings.Join(stringVals, ","), err.Error())
|
|
|
-
|
|
|
- return 0, err
|
|
|
- }
|
|
|
-
|
|
|
- lastId, _ := res.LastInsertId()
|
|
|
-
|
|
|
- q.targetVal.FieldByName(q.pkName).Set(reflect.ValueOf(int(lastId)))
|
|
|
- return lastId, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Read(id int) error {
|
|
|
- if q.err != nil {
|
|
|
- return q.err
|
|
|
- }
|
|
|
-
|
|
|
- if !q.ready {
|
|
|
- q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
|
|
|
- return q.err
|
|
|
- }
|
|
|
-
|
|
|
- if q.targetVal.Kind().String() != "struct" {
|
|
|
- return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
|
|
|
- }
|
|
|
-
|
|
|
- err := q.Where(q.pkName+" = ?", id).Get()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Update() error {
|
|
|
- if q.err != nil {
|
|
|
- return q.err
|
|
|
- }
|
|
|
- if !q.ready {
|
|
|
- q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
|
|
|
- return q.err
|
|
|
- }
|
|
|
-
|
|
|
- if q.targetVal.Kind().String() != "struct" {
|
|
|
- return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
|
|
|
- }
|
|
|
-
|
|
|
- // Check the existing data from db to check rule
|
|
|
- id := q.targetVal.FieldByName(q.pkName)
|
|
|
- tmpModel := reflect.New(q.modelType)
|
|
|
-
|
|
|
- tmpQuery := Query{}
|
|
|
- q.err = tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
|
|
|
-
|
|
|
- if q.err != nil {
|
|
|
- return q.err
|
|
|
- }
|
|
|
-
|
|
|
- existRuleMethod := tmpModel.MethodByName("UpdateRule")
|
|
|
- existRulePassed := existRuleMethod.Call(nil)[0].Bool()
|
|
|
-
|
|
|
- newRuleMethod := q.targetPtr.MethodByName("UpdateRule")
|
|
|
- newRulePassed := newRuleMethod.Call(nil)[0].Bool()
|
|
|
-
|
|
|
- if !existRulePassed || !newRulePassed {
|
|
|
- return errors.New("ERROR: db.Query.Update() ReadRule not passed)")
|
|
|
- }
|
|
|
-
|
|
|
- fields, _, vals := q.getFields(q.targetInt)
|
|
|
-
|
|
|
- for i := 0; i < len(fields); i++ {
|
|
|
- fields[i] = fields[i] + "=?"
|
|
|
- }
|
|
|
-
|
|
|
- for i := 0; i < len(vals); i++ {
|
|
|
- if vals[i] == "<nil>" {
|
|
|
- vals[i] = nil
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- vals = append(vals, int(id.Int()))
|
|
|
- query := fmt.Sprintf(
|
|
|
- "UPDATE %s SET %s WHERE %s=?", q.tblName, strings.Join(fields, ", "), q.pkName,
|
|
|
- )
|
|
|
-
|
|
|
- _, err := Conn.Exec(query, vals...)
|
|
|
- if err != nil {
|
|
|
- fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
|
|
|
- fmt.Println("\nQuery Values: ", vals)
|
|
|
- q.Reset()
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Delete() error {
|
|
|
- if q.err != nil {
|
|
|
- return q.err
|
|
|
- }
|
|
|
- if !q.ready {
|
|
|
- q.err = errors.New("db.Query.New() needs to run first")
|
|
|
- return q.err
|
|
|
- }
|
|
|
- if q.targetVal.Kind().String() != "struct" {
|
|
|
- return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
|
|
|
- }
|
|
|
-
|
|
|
- // Check the existing data from db to check rule
|
|
|
- id := q.targetVal.FieldByName(q.pkName)
|
|
|
- tmpModel := reflect.New(q.targetVal.Type())
|
|
|
-
|
|
|
- tmpQuery := Query{}
|
|
|
- tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
|
|
|
-
|
|
|
- existRuleMethod := tmpModel.MethodByName("DeleteRule")
|
|
|
- existRulePassed := existRuleMethod.Call(nil)[0].Bool()
|
|
|
-
|
|
|
- if !existRulePassed {
|
|
|
- err := "Update rule failed"
|
|
|
- fmt.Println(err)
|
|
|
- return errors.New(err)
|
|
|
- }
|
|
|
-
|
|
|
- query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
|
|
|
-
|
|
|
- _, err := Conn.Exec(query, int(id.Int()))
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- return errors.New("db.Query.Get(), " + err.Error())
|
|
|
- }
|
|
|
- q.Reset()
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Get() error {
|
|
|
-
|
|
|
- q.targetVal.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
|
|
|
-
|
|
|
- // Create a temp model to store the data
|
|
|
- tmpModel := reflect.New(q.modelType)
|
|
|
- tmpModelElem := tmpModel.Elem()
|
|
|
- tmpModelElem.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
|
|
|
- //
|
|
|
-
|
|
|
- query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
|
|
|
-
|
|
|
- err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
|
|
|
- if err != nil {
|
|
|
- return errors.New("db.Query.Get(), " + err.Error())
|
|
|
- }
|
|
|
-
|
|
|
- readRuleMethod := tmpModel.MethodByName("ReadRule")
|
|
|
- readRulePassed := readRuleMethod.Call(nil)[0].Bool()
|
|
|
-
|
|
|
- if !readRulePassed {
|
|
|
- return errors.New("ERROR: db.Query.Get(), read rules not passed")
|
|
|
- }
|
|
|
- for i := 0; i < tmpModelElem.NumField(); i++ {
|
|
|
- q.targetVal.Field(i).Set(tmpModelElem.Field(i))
|
|
|
- }
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) All() error {
|
|
|
-
|
|
|
- query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
|
|
|
-
|
|
|
- rows, err := Conn.Query(query, q.vals...)
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- var set []reflect.Value
|
|
|
- for rows.Next() {
|
|
|
- tmpModel := reflect.New(q.modelType)
|
|
|
- tmpModel.Elem().FieldByName("sec").Set(reflect.ValueOf(q.Auth))
|
|
|
- rows.Scan(StructPointers(tmpModel.Interface())...)
|
|
|
-
|
|
|
- readRuleMethod := tmpModel.MethodByName("ReadRule")
|
|
|
- readRulePassed := readRuleMethod.Call(nil)[0].Bool()
|
|
|
-
|
|
|
- if !readRulePassed {
|
|
|
- return errors.New("ERROR: db.Query.All(), read rules not passed")
|
|
|
- }
|
|
|
- set = append(set, tmpModel.Elem())
|
|
|
- }
|
|
|
- rows.Close()
|
|
|
-
|
|
|
- newSlice := reflect.Append(q.targetVal, set...)
|
|
|
- q.targetVal.Set(newSlice)
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) getFields(s interface{}) ([]string, []string, []any) {
|
|
|
- targetPtr := reflect.ValueOf(s)
|
|
|
- targetVal := targetPtr.Elem()
|
|
|
-
|
|
|
- numFields := targetVal.NumField()
|
|
|
-
|
|
|
- var fields []string
|
|
|
- var plcHdr []string
|
|
|
- var vals []any
|
|
|
-
|
|
|
- for i := 0; i < numFields; i++ {
|
|
|
- fieldName := targetVal.Type().Field(i).Name
|
|
|
- if fieldName == q.pkName {
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- fieldVal := targetVal.Field(i)
|
|
|
-
|
|
|
- tag := targetVal.Type().Field(i).Tag
|
|
|
-
|
|
|
- if tag != "" {
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- if fieldVal.Kind() == reflect.Struct {
|
|
|
- fieldsVal, plcHdrVal, valsVal := q.getFields(fieldVal.Addr().Interface())
|
|
|
- fields = append(fields, fieldsVal...)
|
|
|
- plcHdr = append(plcHdr, plcHdrVal...)
|
|
|
- vals = append(vals, valsVal...)
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- fields = append(fields, fmt.Sprintf("%v", fieldName))
|
|
|
- plcHdr = append(plcHdr, "?")
|
|
|
-
|
|
|
- vals = append(vals, fmt.Sprintf("%v", fieldVal))
|
|
|
- }
|
|
|
-
|
|
|
- return fields, plcHdr, vals
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
|
|
|
- if r.FormValue("prevPage") != "" {
|
|
|
- defaults.PrevPage = r.FormValue("prevPage")
|
|
|
- }
|
|
|
-
|
|
|
- if r.FormValue("nextPage") != "" {
|
|
|
- defaults.NextPage = r.FormValue("nextPage")
|
|
|
- }
|
|
|
-
|
|
|
- if r.FormValue("page") != "" {
|
|
|
- defaults.Page, _ = strconv.Atoi(r.FormValue("page"))
|
|
|
- }
|
|
|
-
|
|
|
- if r.FormValue("perPage") != "" {
|
|
|
- defaults.PerPage, _ = strconv.Atoi(r.FormValue("perPage"))
|
|
|
- }
|
|
|
-
|
|
|
- if r.FormValue("order") != "" {
|
|
|
- defaults.Order = r.FormValue("order")
|
|
|
- }
|
|
|
-
|
|
|
- if r.FormValue("direction") != "" {
|
|
|
- defaults.Direction = r.FormValue("direction")
|
|
|
- }
|
|
|
-
|
|
|
- defaults.Search = r.FormValue("search")
|
|
|
-
|
|
|
-}
|
|
|
-
|
|
|
-func (q *Query) Sort(defaults *Pagination) *Query {
|
|
|
-
|
|
|
- var count int
|
|
|
-
|
|
|
- query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
|
|
|
- _ = Conn.QueryRow(query, q.vals...).Scan(&count)
|
|
|
-
|
|
|
- pages := math.Ceil(float64(count) / float64(defaults.PerPage))
|
|
|
- defaults.Pages = int(pages)
|
|
|
-
|
|
|
- if defaults.PrevPage != "" {
|
|
|
- defaults.Page--
|
|
|
- }
|
|
|
-
|
|
|
- if defaults.NextPage != "" {
|
|
|
- defaults.Page++
|
|
|
- }
|
|
|
-
|
|
|
- if defaults.Page <= 0 {
|
|
|
- defaults.Page = 1
|
|
|
- }
|
|
|
-
|
|
|
- if defaults.Page > defaults.Pages {
|
|
|
- defaults.Page = defaults.Pages
|
|
|
-
|
|
|
- }
|
|
|
-
|
|
|
- defaults.Start = (defaults.Page - 1) * defaults.PerPage
|
|
|
-
|
|
|
- if defaults.Start < 0 {
|
|
|
- defaults.Start = 0
|
|
|
- }
|
|
|
-
|
|
|
- q.Order(defaults.Order + " " + defaults.Direction)
|
|
|
- // this.limit(sort.start, sort.perPage);
|
|
|
-
|
|
|
- return q
|
|
|
-}
|