|
- package db
- import (
- "errors"
- "fmt"
- "math"
- "net/http"
- "reflect"
- "strconv"
- "strings"
- "git.clearsky.net.au/cody/gex.git/sec"
- )
// Pagination carries list-view paging and sorting state from an HTTP request
// (filled by Query.SortParams) into a query (consumed by Query.Sort).
type Pagination struct {
	// PrevPage and NextPage are navigation flags from the request; a
	// non-empty value moves Page one step backward or forward in Sort.
	PrevPage string
	NextPage string

	// Page is the current page, clamped by Sort to at least 1 (or 0 when
	// there are no pages); Pages is the total page count and PerPage the
	// number of rows per page.
	Page    int
	Pages   int
	PerPage int

	// Order is the column to sort by and Direction its sort direction.
	Order     string
	Direction string

	// Search is the search term copied from the request.
	Search string

	// Start is the 0-based offset of the first row on the current page.
	Start int
}
- type Query struct {
- targetInt interface{}
- targetPtr reflect.Value
- targetVal reflect.Value
- modelType reflect.Type
- tblName string
- pkName string
- Sess sec.Sess
- err error
- ready bool
- sel string
- where string
- order string
- vals []any
- }
- func (q *Query) New(targetInt interface{}, sess sec.Sess) *Query {
- q.targetInt = targetInt
- q.targetPtr = reflect.ValueOf(q.targetInt)
- if q.targetPtr.Kind().String() != "ptr" {
- q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
- return q
- }
- q.targetVal = q.targetPtr.Elem()
- q.sel = "*"
- // Get the session
- q.Sess = sess
- if q.targetVal.Kind().String() == "slice" {
- q.newSlice()
- return q
- }
- //Get the type of the model
- q.modelType = q.targetVal.Type()
- // Get the table name
- tblMth := q.targetPtr.MethodByName("Table")
- if !tblMth.IsValid() {
- q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
- return q
- }
- q.tblName = tblMth.Call(nil)[0].String()
- // Get the primary key name
- pkMth := q.targetPtr.MethodByName("PK")
- if !tblMth.IsValid() {
- q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
- return q
- }
- q.pkName = pkMth.Call(nil)[0].String()
- q.ready = true
- return q
- }
- func (q *Query) newSlice() {
- // Get the type of the model
- q.modelType = q.targetVal.Type().Elem()
- // create a new model so we can run the table and pk methods
- tmpModel := reflect.New(q.modelType)
- // Get the table name
- tblMth := tmpModel.MethodByName("Table")
- if !tblMth.IsValid() {
- q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
- }
- q.tblName = tblMth.Call(nil)[0].String()
- // Get the primary key name
- pkMth := tmpModel.MethodByName("PK")
- if !tblMth.IsValid() {
- q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
- }
- q.pkName = pkMth.Call(nil)[0].String()
- q.ready = true
- }
- func (q *Query) Reset() {
- q.sel = "*"
- q.where = ""
- var new []any
- q.vals = new
- }
- func (q *Query) Select(s string) *Query {
- if q.err != nil {
- return q
- }
- if !q.ready {
- q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
- return q
- }
- q.sel = s
- return q
- }
- func (q *Query) Where(str string, vals ...any) *Query {
- if q.err != nil {
- return q
- }
- if !q.ready {
- q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
- return q
- }
- q.where = fmt.Sprintf("WHERE (%s) ", str)
- q.vals = append(q.vals, vals...)
- return q
- }
- // Security check here to ensure str is a field (column)
- func (q *Query) Order(str string) *Query {
- q.order += "ORDER BY " + str
- return q
- }
- func (q *Query) Create() (int64, error) {
- if q.err != nil {
- return 0, q.err
- }
- if !q.ready {
- q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
- return 0, q.err
- }
- if q.targetVal.Kind().String() != "struct" {
- return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
- }
- fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
- for i := 0; i < len(vals); i++ {
- if vals[i] == "<nil>" {
- vals[i] = nil
- }
- }
- query := fmt.Sprintf(
- "INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
- )
- q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
- createRuleMethod := q.targetPtr.MethodByName("CreateRule")
- createRulePassed := createRuleMethod.Call(nil)[0].Bool()
- if !createRulePassed {
- return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
- }
- res, err := Conn.Exec(query, vals...)
- if err != nil {
- stringVals := make([]string, len(vals))
- for i, v := range vals {
- stringVals[i] = fmt.Sprint(v)
- }
- fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s ", query, strings.Join(stringVals, ","), err.Error())
- return 0, err
- }
- lastId, _ := res.LastInsertId()
- q.targetVal.FieldByName(q.pkName).Set(reflect.ValueOf(int(lastId)))
- return lastId, nil
- }
- func (q *Query) Read(id int) error {
- if q.err != nil {
- return q.err
- }
- if !q.ready {
- q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
- return q.err
- }
- if q.targetVal.Kind().String() != "struct" {
- return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
- }
- err := q.Where(q.pkName+" = ?", id).Get()
- if err != nil {
- return err
- }
- return nil
- }
- func (q *Query) Update() error {
- if q.err != nil {
- return q.err
- }
- if !q.ready {
- q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
- return q.err
- }
- if q.targetVal.Kind().String() != "struct" {
- return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
- }
- // Check the existing data from db to check rule
- id := q.targetVal.FieldByName(q.pkName)
- tmpModel := reflect.New(q.modelType)
- tmpQuery := Query{}
- q.err = tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
- if q.err != nil {
- return q.err
- }
- existRuleMethod := tmpModel.MethodByName("UpdateRule")
- existRulePassed := existRuleMethod.Call(nil)[0].Bool()
- newRuleMethod := q.targetPtr.MethodByName("UpdateRule")
- newRulePassed := newRuleMethod.Call(nil)[0].Bool()
- if !existRulePassed || !newRulePassed {
- return errors.New("ERROR: db.Query.Update() ReadRule not passed)")
- }
- fields, _, vals := q.getFields(q.targetInt)
- for i := 0; i < len(fields); i++ {
- fields[i] = fields[i] + "=?"
- }
- for i := 0; i < len(vals); i++ {
- if vals[i] == "<nil>" {
- vals[i] = nil
- }
- }
- vals = append(vals, int(id.Int()))
- query := fmt.Sprintf(
- "UPDATE %s SET %s WHERE %s=?", q.tblName, strings.Join(fields, ", "), q.pkName,
- )
- _, err := Conn.Exec(query, vals...)
- if err != nil {
- fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
- fmt.Println("\nQuery Values: ", vals)
- q.Reset()
- return err
- }
- return nil
- }
- func (q *Query) Delete() error {
- if q.err != nil {
- return q.err
- }
- if !q.ready {
- q.err = errors.New("db.Query.New() needs to run first")
- return q.err
- }
- if q.targetVal.Kind().String() != "struct" {
- return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
- }
- // Check the existing data from db to check rule
- id := q.targetVal.FieldByName(q.pkName)
- tmpModel := reflect.New(q.targetVal.Type())
- tmpQuery := Query{}
- tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
- existRuleMethod := tmpModel.MethodByName("DeleteRule")
- existRulePassed := existRuleMethod.Call(nil)[0].Bool()
- if !existRulePassed {
- err := "Update rule failed"
- fmt.Println(err)
- return errors.New(err)
- }
- query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
- _, err := Conn.Exec(query, int(id.Int()))
- if err != nil {
- return errors.New("db.Query.Get(), " + err.Error())
- }
- q.Reset()
- return nil
- }
- func (q *Query) Get() error {
- q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
- // Create a temp model to store the data
- tmpModel := reflect.New(q.modelType)
- tmpModelElem := tmpModel.Elem()
- tmpModelElem.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
- //
- query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
- err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
- if err != nil {
- return errors.New("db.Query.Get(), " + err.Error())
- }
- readRuleMethod := tmpModel.MethodByName("ReadRule")
- readRulePassed := readRuleMethod.Call(nil)[0].Bool()
- if !readRulePassed {
- return errors.New("ERROR: db.Query.Get(), read rules not passed")
- }
- for i := 0; i < tmpModelElem.NumField(); i++ {
- q.targetVal.Field(i).Set(tmpModelElem.Field(i))
- }
- return nil
- }
- func (q *Query) All() error {
- query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
- rows, err := Conn.Query(query, q.vals...)
- if err != nil {
- return err
- }
- var set []reflect.Value
- for rows.Next() {
- tmpModel := reflect.New(q.modelType)
- tmpModel.Elem().FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
- rows.Scan(StructPointers(tmpModel.Interface())...)
- readRuleMethod := tmpModel.MethodByName("ReadRule")
- readRulePassed := readRuleMethod.Call(nil)[0].Bool()
- if !readRulePassed {
- return errors.New("ERROR: db.Query.All(), read rules not passed")
- }
- set = append(set, tmpModel.Elem())
- }
- rows.Close()
- newSlice := reflect.Append(q.targetVal, set...)
- q.targetVal.Set(newSlice)
- return nil
- }
- func (q *Query) getFields(s interface{}) ([]string, []string, []any) {
- targetPtr := reflect.ValueOf(s)
- targetVal := targetPtr.Elem()
- numFields := targetVal.NumField()
- var fields []string
- var plcHdr []string
- var vals []any
- for i := 0; i < numFields; i++ {
- fieldName := targetVal.Type().Field(i).Name
- if fieldName == q.pkName {
- continue
- }
- fieldVal := targetVal.Field(i)
- tag := targetVal.Type().Field(i).Tag
- if tag != "" {
- continue
- }
- if fieldVal.Kind() == reflect.Struct {
- fieldsVal, plcHdrVal, valsVal := q.getFields(fieldVal.Addr().Interface())
- fields = append(fields, fieldsVal...)
- plcHdr = append(plcHdr, plcHdrVal...)
- vals = append(vals, valsVal...)
- continue
- }
- fields = append(fields, fmt.Sprintf("%v", fieldName))
- plcHdr = append(plcHdr, "?")
- vals = append(vals, fmt.Sprintf("%v", fieldVal))
- }
- return fields, plcHdr, vals
- }
- func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
- if r.FormValue("prevPage") != "" {
- defaults.PrevPage = r.FormValue("prevPage")
- }
- if r.FormValue("nextPage") != "" {
- defaults.NextPage = r.FormValue("nextPage")
- }
- if r.FormValue("page") != "" {
- defaults.Page, _ = strconv.Atoi(r.FormValue("page"))
- }
- if r.FormValue("perPage") != "" {
- defaults.PerPage, _ = strconv.Atoi(r.FormValue("perPage"))
- }
- if r.FormValue("order") != "" {
- defaults.Order = r.FormValue("order")
- }
- if r.FormValue("direction") != "" {
- defaults.Direction = r.FormValue("direction")
- }
- defaults.Search = r.FormValue("search")
- }
- func (q *Query) Sort(defaults *Pagination) *Query {
- var count int
- query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
- _ = Conn.QueryRow(query, q.vals...).Scan(&count)
- pages := math.Ceil(float64(count) / float64(defaults.PerPage))
- defaults.Pages = int(pages)
- if defaults.PrevPage != "" {
- defaults.Page--
- }
- if defaults.NextPage != "" {
- defaults.Page++
- }
- if defaults.Page <= 0 {
- defaults.Page = 1
- }
- if defaults.Page > defaults.Pages {
- defaults.Page = defaults.Pages
- }
- defaults.Start = (defaults.Page - 1) * defaults.PerPage
- if defaults.Start < 0 {
- defaults.Start = 0
- }
- q.Order(defaults.Order + " " + defaults.Direction)
- // this.limit(sort.start, sort.perPage);
- return q
- }
|