package db

import (
	"errors"
	"fmt"
	"math"
	"net/http"
	"reflect"
	"strconv"
	"strings"

	"git.clearsky.net.au/cody/gex.git/sec"
)

// Pagination carries list-paging state between an HTTP request (SortParams)
// and a query (Sort).
type Pagination struct {
	PrevPage  string // non-empty: Sort moves back one page
	NextPage  string // non-empty: Sort moves forward one page
	Page      int    // current page; Sort clamps it to [1, Pages] (0 when there are no rows)
	Pages     int    // total page count, computed by Sort
	PerPage   int    // rows per page; Sort treats values below 1 as 1
	Order     string // column to order by; Sort ignores it unless it is a plain identifier
	Direction string // "ASC" or "DESC" (case-insensitive); anything else is ignored by Sort
	Search    string // raw "search" form value; not consulted by Sort
	Start     int    // zero-based offset of the first row on Page, computed by Sort
}

// Query is a reflection-driven query builder bound to one model (a pointer to
// a struct) or a model list (a pointer to a slice of structs). Call New first.
// Builder errors are recorded on the Query and returned by the terminal
// methods (Create, Read, Update, Delete, Get, All).
type Query struct {
	targetInt interface{}   // the pointer passed to New
	targetPtr reflect.Value // reflect.ValueOf(targetInt)
	targetVal reflect.Value // the struct or slice targetPtr points at
	modelType reflect.Type  // type of one model (the element type for slices)
	tblName   string        // result of the model's Table() method
	pkName    string        // result of the model's PK() method; also the PK struct field name
	Sess      sec.Sess      // session exposed to models through their Sess field
	err       error         // first builder error; sticky until New is called again
	ready     bool          // true once New succeeded
	sel       string        // SELECT column list, "*" by default
	where     string        // "WHERE (...) AND (...) " or ""
	order     string        // "ORDER BY ..." or ""
	vals      []any         // placeholder arguments for where, in order
}

// New binds q to targetInt and records sess for permission-rule checks.
//
// targetInt must be a non-nil pointer to a model struct or to a slice of model
// structs. The model must define zero-argument Table() and PK() methods that
// return the table name and primary-key column/field name. Any previous state
// on q is discarded. Failures are stored on q and returned by later calls.
func (q *Query) New(targetInt interface{}, sess sec.Sess) *Query {
	// Start from a clean builder so a reused Query does not carry over the
	// previous WHERE clause, placeholder values or sticky error.
	*q = Query{}
	q.targetInt = targetInt
	q.targetPtr = reflect.ValueOf(targetInt)
	if q.targetPtr.Kind() != reflect.Ptr {
		q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
		return q
	}
	if q.targetPtr.IsNil() {
		q.err = errors.New("ERROR: db.Query.New(), targetInt must not be a nil pointer")
		return q
	}
	q.targetVal = q.targetPtr.Elem()
	q.sel = "*"
	q.Sess = sess

	if q.targetVal.Kind() == reflect.Slice {
		q.newSlice()
		return q
	}

	q.modelType = q.targetVal.Type()
	q.tblName, q.pkName, q.err = tableAndPK(q.targetPtr)
	q.ready = q.err == nil
	return q
}

// newSlice finishes New for a slice target: the model type is the slice's
// element type, and a throwaway *Model is created so Table() and PK() can be
// called.
func (q *Query) newSlice() {
	q.modelType = q.targetVal.Type().Elem()
	tmpModel := reflect.New(q.modelType)
	q.tblName, q.pkName, q.err = tableAndPK(tmpModel)
	q.ready = q.err == nil
}

// tableAndPK calls the zero-argument Table() and PK() methods on model and
// returns their string results. It reports an error instead of panicking when
// either method is missing or has the wrong shape.
func tableAndPK(model reflect.Value) (tbl, pk string, err error) {
	tblMth, ok := zeroArgMethod(model, "Table", reflect.String)
	if !ok {
		return "", "", errors.New("ERROR: db.Query.New(), target Table() method must be defined")
	}
	pkMth, ok := zeroArgMethod(model, "PK", reflect.String)
	if !ok {
		return "", "", errors.New("ERROR: db.Query.New(), target PK() method must be defined")
	}
	return tblMth.Call(nil)[0].String(), pkMth.Call(nil)[0].String(), nil
}

// zeroArgMethod looks up method name on v. It reports false unless the method
// exists, takes no arguments and has a first result of kind out.
func zeroArgMethod(v reflect.Value, name string, out reflect.Kind) (reflect.Value, bool) {
	m := v.MethodByName(name)
	if !m.IsValid() {
		return reflect.Value{}, false
	}
	t := m.Type()
	if t.NumIn() != 0 || t.NumOut() == 0 || t.Out(0).Kind() != out {
		return reflect.Value{}, false
	}
	return m, true
}

// checkRule calls the zero-argument bool method name (e.g. "ReadRule") on
// model and returns its result. A missing or mis-typed rule method counts as
// a failed check, so permissions fail closed rather than panicking.
func checkRule(model reflect.Value, name string) bool {
	m, ok := zeroArgMethod(model, name, reflect.Bool)
	if !ok {
		return false
	}
	return m.Call(nil)[0].Bool()
}

// injectSess points the Sess field of model (a struct value) at q.Sess, so
// that rule methods can consult the session. Models without a settable,
// type-compatible Sess field are left untouched.
func (q *Query) injectSess(model reflect.Value) {
	field := model.FieldByName("Sess")
	sessPtr := reflect.ValueOf(&q.Sess)
	if field.IsValid() && field.CanSet() && sessPtr.Type().AssignableTo(field.Type()) {
		field.Set(sessPtr)
	}
}

// pkValue returns the target's primary-key field (named by PK()) as an int.
func (q *Query) pkValue() (int, error) {
	field := q.targetVal.FieldByName(q.pkName)
	if !field.IsValid() || !field.CanInt() {
		return 0, fmt.Errorf("ERROR: db.Query, primary key field %q must be an integer field", q.pkName)
	}
	return int(field.Int()), nil
}

// nullifyEmpty replaces empty-string values with nil so they are written as
// SQL NULL.
func nullifyEmpty(vals []any) {
	for i, v := range vals {
		if v == "" {
			vals[i] = nil
		}
	}
}

// Reset clears the SELECT list, WHERE clause, ORDER BY clause and placeholder
// values. The model binding made by New is kept.
func (q *Query) Reset() {
	q.sel = "*"
	q.where = ""
	q.order = ""
	q.vals = nil
}

// Select sets the column list used by Get and All (default "*").
func (q *Query) Select(s string) *Query {
	if q.err != nil {
		return q
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
		return q
	}
	q.sel = s
	return q
}

// Where adds a parenthesised condition with "?" placeholders bound to vals.
// Repeated calls are combined with AND, so clause and values stay in step.
func (q *Query) Where(str string, vals ...any) *Query {
	if q.err != nil {
		return q
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
		return q
	}
	if q.where == "" {
		q.where = fmt.Sprintf("WHERE (%s) ", str)
	} else {
		q.where += fmt.Sprintf("AND (%s) ", str)
	}
	q.vals = append(q.vals, vals...)
	return q
}

// Order adds str as an ORDER BY term. Repeated calls add further
// comma-separated terms.
//
// str is interpolated verbatim into the SQL, so it must come from trusted
// code. Request-supplied ordering should go through Sort, which validates it.
func (q *Query) Order(str string) *Query {
	if q.order == "" {
		q.order = "ORDER BY " + str
	} else {
		q.order += ", " + str
	}
	return q
}

// Create inserts the target struct, after its CreateRule() has passed. Every
// field except the primary key and struct-tagged fields is written, and empty
// strings become NULL. On success it stores the new id in the PK field
// (when the driver reports one) and returns it.
func (q *Query) Create() (int64, error) {
	if q.err != nil {
		return 0, q.err
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
		return 0, q.err
	}
	if q.targetVal.Kind() != reflect.Struct {
		return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
	}

	fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
	nullifyEmpty(vals)
	query := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		q.tblName,
		strings.Join(fields, ", "),
		strings.Join(plcHdr, ", "),
	)

	q.injectSess(q.targetVal)
	if !checkRule(q.targetPtr, "CreateRule") {
		return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
	}

	res, err := Conn.Exec(query, vals...)
	if err != nil {
		stringVals := make([]string, len(vals))
		for i, v := range vals {
			stringVals[i] = fmt.Sprint(v)
		}
		fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s\n", query, strings.Join(stringVals, ","), err.Error())
		return 0, err
	}

	lastID, err := res.LastInsertId()
	if err != nil {
		// The row was inserted; this driver just cannot report the new id.
		return 0, nil
	}
	if pk := q.targetVal.FieldByName(q.pkName); pk.IsValid() && pk.CanSet() && pk.CanInt() {
		pk.SetInt(lastID)
	}
	return lastID, nil
}

// Read loads the row whose primary key equals id into the target struct. It
// is combined with any earlier Where conditions, and the row must pass the
// model's ReadRule().
func (q *Query) Read(id int) error {
	if q.err != nil {
		return q.err
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
		return q.err
	}
	if q.targetVal.Kind() != reflect.Struct {
		return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
	}
	return q.Where(q.pkName+" = ?", id).Get()
}

// Update writes the target struct back to its row, identified by the PK field.
// UpdateRule() must pass on both the stored row and the new data. Every field
// except the primary key and struct-tagged fields is written, and empty
// strings become NULL.
func (q *Query) Update() error {
	if q.err != nil {
		return q.err
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
		return q.err
	}
	if q.targetVal.Kind() != reflect.Struct {
		return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
	}
	id, err := q.pkValue()
	if err != nil {
		return err
	}

	// Load the stored row so the rule is checked against the existing data
	// too. A failed load is returned, not stored in q.err, so q stays usable.
	existing := reflect.New(q.modelType)
	var tmpQuery Query
	if err := tmpQuery.New(existing.Interface(), q.Sess).Read(id); err != nil {
		return err
	}

	q.injectSess(q.targetVal)
	if !checkRule(existing, "UpdateRule") || !checkRule(q.targetPtr, "UpdateRule") {
		return errors.New("ERROR: db.Query.Update(), update rules not passed")
	}

	fields, _, vals := q.getFields(q.targetInt)
	for i := range fields {
		fields[i] += "=?"
	}
	nullifyEmpty(vals)
	vals = append(vals, id)
	query := fmt.Sprintf(
		"UPDATE %s SET %s WHERE %s=?",
		q.tblName,
		strings.Join(fields, ", "),
		q.pkName,
	)
	if _, err := Conn.Exec(query, vals...); err != nil {
		fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
		fmt.Println("\nQuery Values: ", vals)
		q.Reset()
		return err
	}
	return nil
}

// Delete removes the row identified by the target's PK field, after the
// stored row has been loaded and its DeleteRule() has passed. On success the
// builder state is Reset.
func (q *Query) Delete() error {
	if q.err != nil {
		return q.err
	}
	if !q.ready {
		q.err = errors.New("db.Query.New() needs to run first")
		return q.err
	}
	if q.targetVal.Kind() != reflect.Struct {
		return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
	}
	id, err := q.pkValue()
	if err != nil {
		return err
	}

	// The rule must be evaluated on the real stored row; if it cannot be
	// loaded (missing, or ReadRule denied) the delete must not proceed.
	existing := reflect.New(q.modelType)
	var tmpQuery Query
	if err := tmpQuery.New(existing.Interface(), q.Sess).Read(id); err != nil {
		return err
	}
	if !checkRule(existing, "DeleteRule") {
		return errors.New("ERROR: db.Query.Delete(), delete rules not passed")
	}

	query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
	if _, err := Conn.Exec(query, id); err != nil {
		return fmt.Errorf("db.Query.Delete(), %w", err)
	}
	q.Reset()
	return nil
}

// Get runs the built SELECT and scans the first row into the target struct.
// The row is copied into the target only if it passes the model's ReadRule().
// Scan errors, including no rows, are wrapped with %w.
func (q *Query) Get() error {
	if q.err != nil {
		return q.err
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Get(), db.Query.New() needs to run first")
		return q.err
	}
	if q.targetVal.Kind() != reflect.Struct {
		return errors.New("ERROR: db.Query.Get(), model must be a struct (are you using a slice)")
	}
	q.injectSess(q.targetVal)

	// Scan into a temporary model so the target is untouched unless ReadRule passes.
	tmpModel := reflect.New(q.modelType)
	tmpElem := tmpModel.Elem()
	q.injectSess(tmpElem)

	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
	if err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...); err != nil {
		return fmt.Errorf("db.Query.Get(), %w", err)
	}
	if !checkRule(tmpModel, "ReadRule") {
		return errors.New("ERROR: db.Query.Get(), read rules not passed")
	}
	// Whole-struct copy; unlike a per-field Set it also works with unexported fields.
	q.targetVal.Set(tmpElem)
	return nil
}

// All runs the built SELECT and appends every row to the target slice. If any
// row fails the model's ReadRule(), nothing is appended and an error is
// returned.
func (q *Query) All() error {
	if q.err != nil {
		return q.err
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.All(), db.Query.New() needs to run first")
		return q.err
	}
	if q.targetVal.Kind() != reflect.Slice {
		return errors.New("ERROR: db.Query.All(), target must be a pointer to a slice")
	}

	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
	rows, err := Conn.Query(query, q.vals...)
	if err != nil {
		return err
	}
	// Close on every path, including the early returns below, so the
	// connection goes back to the pool.
	defer rows.Close()

	var set []reflect.Value
	for rows.Next() {
		tmpModel := reflect.New(q.modelType)
		q.injectSess(tmpModel.Elem())
		if err := rows.Scan(StructPointers(tmpModel.Interface())...); err != nil {
			return fmt.Errorf("db.Query.All(), %w", err)
		}
		if !checkRule(tmpModel, "ReadRule") {
			return errors.New("ERROR: db.Query.All(), read rules not passed")
		}
		set = append(set, tmpModel.Elem())
	}
	if err := rows.Err(); err != nil {
		return fmt.Errorf("db.Query.All(), %w", err)
	}

	q.targetVal.Set(reflect.Append(q.targetVal, set...))
	return nil
}

// getFields walks the struct pointed to by s and returns, in field order:
// column names (the Go field names), one "?" placeholder per column, and each
// value formatted with %v.
//
// It skips the primary-key field and any field carrying a struct tag, and it
// descends into nested struct fields.
//
// NOTE(review): struct-typed values such as time.Time are also descended
// into, and an unexported struct field would panic on Interface(). Confirm
// that models avoid these.
func (q *Query) getFields(s any) ([]string, []string, []any) {
	targetVal := reflect.ValueOf(s).Elem()
	targetType := targetVal.Type()

	var fields, plcHdr []string
	var vals []any
	for i := 0; i < targetVal.NumField(); i++ {
		sf := targetType.Field(i)
		if sf.Name == q.pkName || sf.Tag != "" {
			continue
		}
		fieldVal := targetVal.Field(i)
		if fieldVal.Kind() == reflect.Struct {
			subFields, subPlcHdr, subVals := q.getFields(fieldVal.Addr().Interface())
			fields = append(fields, subFields...)
			plcHdr = append(plcHdr, subPlcHdr...)
			vals = append(vals, subVals...)
			continue
		}
		fields = append(fields, sf.Name)
		plcHdr = append(plcHdr, "?")
		vals = append(vals, fmt.Sprintf("%v", fieldVal))
	}
	return fields, plcHdr, vals
}

// SortParams overlays paging and sorting values from the request's form
// (prevPage, nextPage, page, perPage, order, direction) onto defaults. Absent
// values and unparsable numbers leave the existing defaults untouched. Search
// is always replaced by the "search" value.
func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
	if v := r.FormValue("prevPage"); v != "" {
		defaults.PrevPage = v
	}
	if v := r.FormValue("nextPage"); v != "" {
		defaults.NextPage = v
	}
	if n, err := strconv.Atoi(r.FormValue("page")); err == nil {
		defaults.Page = n
	}
	if n, err := strconv.Atoi(r.FormValue("perPage")); err == nil {
		defaults.PerPage = n
	}
	if v := r.FormValue("order"); v != "" {
		defaults.Order = v
	}
	if v := r.FormValue("direction"); v != "" {
		defaults.Direction = v
	}
	defaults.Search = r.FormValue("search")
}

// Sort counts the rows matching the current WHERE clause and fills in
// defaults.Pages, Page and Start. It also adds defaults.Order/Direction as an
// ORDER BY term when they are safe to interpolate.
//
// Order must be a plain identifier and Direction must be ASC/DESC. Anything
// else is ignored, because these values typically come straight from the
// request via SortParams. A failed count query is recorded on q.
func (q *Query) Sort(defaults *Pagination) *Query {
	if q.err != nil {
		return q
	}
	if !q.ready {
		q.err = errors.New("ERROR: db.Query.Sort(), db.Query.New() needs to run first")
		return q
	}
	// Guard the page-count division against a zero or negative page size,
	// which may come from the request.
	if defaults.PerPage < 1 {
		defaults.PerPage = 1
	}

	var count int
	query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
	if err := Conn.QueryRow(query, q.vals...).Scan(&count); err != nil {
		q.err = fmt.Errorf("db.Query.Sort(), %w", err)
		return q
	}
	defaults.Pages = int(math.Ceil(float64(count) / float64(defaults.PerPage)))

	if defaults.PrevPage != "" {
		defaults.Page--
	}
	if defaults.NextPage != "" {
		defaults.Page++
	}
	if defaults.Page <= 0 {
		defaults.Page = 1
	}
	if defaults.Page > defaults.Pages {
		defaults.Page = defaults.Pages
	}
	defaults.Start = (defaults.Page - 1) * defaults.PerPage
	if defaults.Start < 0 {
		defaults.Start = 0
	}

	if clause, ok := orderClause(defaults.Order, defaults.Direction); ok {
		q.Order(clause)
	}
	// NOTE(review): Start and PerPage are computed but no LIMIT/OFFSET is
	// applied (the original carried a commented-out
	// `this.limit(sort.start, sort.perPage)`). Presumably callers page the
	// results themselves; confirm that before adding a LIMIT here.
	return q
}

// orderClause builds a "column [ASC|DESC]" ORDER BY term from untrusted input.
// It reports false when column is not a plain identifier, so request data can
// never inject SQL. An empty or unrecognised direction is dropped.
func orderClause(column, direction string) (string, bool) {
	if !isPlainIdentifier(column) {
		return "", false
	}
	switch dir := strings.ToUpper(strings.TrimSpace(direction)); dir {
	case "ASC", "DESC":
		return column + " " + dir, true
	default:
		return column, true
	}
}

// isPlainIdentifier reports whether s is non-empty and consists only of ASCII
// letters, digits, '_' and '.' (for table-qualified columns).
func isPlainIdentifier(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_', r == '.':
		default:
			return false
		}
	}
	return true
}