From 2d05b5a80f6a83f47181069747ee0e81834ae81a Mon Sep 17 00:00:00 2001 From: roberChen Date: Thu, 25 Mar 2021 10:42:18 +0800 Subject: [PATCH] fix(config): use web.AppConfig instead of adapter - use web.AppConfig, and change related codes TODO: handles errors --- commands/command.go | 62 +++++++++++++++---------------- commands/install.go | 7 +--- commands/migrate/migrate.go | 6 +-- conf/enumerate.go | 41 ++++++++++---------- conf/mail.go | 18 ++++----- controllers/BaseController.go | 4 +- controllers/DocumentController.go | 4 +- models/Blog.go | 6 +-- models/DocumentModel.go | 4 +- models/Member.go | 30 +++++++++------ routers/filter.go | 5 ++- 11 files changed, 95 insertions(+), 92 deletions(-) diff --git a/commands/command.go b/commands/command.go index 2946bc83..410a2c4c 100644 --- a/commands/command.go +++ b/commands/command.go @@ -34,17 +34,17 @@ import ( // RegisterDataBase 注册数据库 func RegisterDataBase() { logs.Info("正在初始化数据库配置.") - dbadapter := adapter.AppConfig.String("db_adapter") + dbadapter,_ := web.AppConfig.String("db_adapter") orm.DefaultTimeLoc = time.Local orm.DefaultRowsLimit = -1 if strings.EqualFold(dbadapter, "mysql") { - host := adapter.AppConfig.String("db_host") - database := adapter.AppConfig.String("db_database") - username := adapter.AppConfig.String("db_username") - password := adapter.AppConfig.String("db_password") + host,_ := web.AppConfig.String("db_host") + database, _ := web.AppConfig.String("db_database") + username,_ := web.AppConfig.String("db_username") + password,_ := web.AppConfig.String("db_password") - timezone := adapter.AppConfig.String("timezone") + timezone,_ := web.AppConfig.String("timezone") location, err := time.LoadLocation(timezone) if err == nil { orm.DefaultTimeLoc = location @@ -52,7 +52,7 @@ func RegisterDataBase() { logs.Error("加载时区配置信息失败,请检查是否存在 ZONEINFO 环境变量->", err) } - port := adapter.AppConfig.String("db_port") + port,_ := web.AppConfig.String("db_port") dataSource := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=%s", username, password, host, port, database, url.QueryEscape(timezone)) @@ -63,7 +63,7 @@ func RegisterDataBase() { } else if strings.EqualFold(dbadapter, "sqlite3") { - database := adapter.AppConfig.String("db_database") + database,_ := web.AppConfig.String("db_database") if strings.HasPrefix(database, "./") { database = filepath.Join(conf.WorkingDirectory, string(database[1:])) } @@ -123,11 +123,11 @@ func RegisterLogger(log string) { _ = logs.SetLogger("console") logs.EnableFuncCallDepth(true) - if adapter.AppConfig.DefaultBool("log_is_async", true) { + if web.AppConfig.DefaultBool("log_is_async", true) { logs.Async(1e3) } if log == "" { - logPath, err := filepath.Abs(adapter.AppConfig.DefaultString("log_path", conf.WorkingDir("runtime", "logs"))) + logPath, err := filepath.Abs(web.AppConfig.DefaultString("log_path", conf.WorkingDir("runtime", "logs"))) if err == nil { log = logPath } else { @@ -147,19 +147,19 @@ func RegisterLogger(log string) { config["perm"] = "0755" config["rotate"] = true - if maxLines := adapter.AppConfig.DefaultInt("log_maxlines", 1000000); maxLines > 0 { + if maxLines := web.AppConfig.DefaultInt("log_maxlines", 1000000); maxLines > 0 { config["maxLines"] = maxLines } - if maxSize := adapter.AppConfig.DefaultInt("log_maxsize", 1<<28); maxSize > 0 { + if maxSize := web.AppConfig.DefaultInt("log_maxsize", 1<<28); maxSize > 0 { config["maxsize"] = maxSize } - if !adapter.AppConfig.DefaultBool("log_daily", true) { + if !web.AppConfig.DefaultBool("log_daily", true) { config["daily"] = false } - if maxDays := adapter.AppConfig.DefaultInt("log_maxdays", 7); maxDays > 0 { + if maxDays := web.AppConfig.DefaultInt("log_maxdays", 7); maxDays > 0 { config["maxdays"] = maxDays } - if level := adapter.AppConfig.DefaultString("log_level", "Trace"); level != "" { + if level := web.AppConfig.DefaultString("log_level", "Trace"); level != "" { switch level { case "Emergency": config["level"] = logs.LevelEmergency @@ -220,13 +220,13 @@ func RegisterFunction() { os.Exit(-1) } err = adapter.AddFuncMap("cdn", func(p string) string { - cdn := adapter.AppConfig.DefaultString("cdn", "") + cdn := web.AppConfig.DefaultString("cdn", "") if strings.HasPrefix(p, "http://") || strings.HasPrefix(p, "https://") { return p } //如果没有设置cdn,则使用baseURL拼接 if cdn == "" { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "") + baseUrl := web.AppConfig.DefaultString("baseurl", "") if strings.HasPrefix(p, "/") && strings.HasSuffix(baseUrl, "/") { return baseUrl + p[1:] @@ -311,7 +311,7 @@ func ResolveCommand(args []string) { log.Fatal("An error occurred:", err) } if conf.LogFile == "" { - logPath, err := filepath.Abs(adapter.AppConfig.DefaultString("log_path", conf.WorkingDir("runtime", "logs"))) + logPath, err := filepath.Abs(web.AppConfig.DefaultString("log_path", conf.WorkingDir("runtime", "logs"))) if err == nil { conf.LogFile = logPath } else { @@ -319,7 +319,7 @@ func ResolveCommand(args []string) { } } - conf.AutoLoadDelay = adapter.AppConfig.DefaultInt("config_auto_delay", 0) + conf.AutoLoadDelay = web.AppConfig.DefaultInt("config_auto_delay", 0) uploads := conf.WorkingDir("uploads") _ = os.MkdirAll(uploads, 0666) @@ -348,15 +348,15 @@ func ResolveCommand(args []string) { //注册缓存管道 func RegisterCache() { - isOpenCache := adapter.AppConfig.DefaultBool("cache", false) + isOpenCache := web.AppConfig.DefaultBool("cache", false) if !isOpenCache { cache.Init(&cache.NullCache{}) return } logs.Info("正常初始化缓存配置.") - cacheProvider := adapter.AppConfig.String("cache_provider") + cacheProvider, _ := web.AppConfig.String("cache_provider") if cacheProvider == "file" { - cacheFilePath := adapter.AppConfig.DefaultString("cache_file_path", "./runtime/cache/") + cacheFilePath := web.AppConfig.DefaultString("cache_file_path", "./runtime/cache/") if strings.HasPrefix(cacheFilePath, "./") { cacheFilePath = filepath.Join(conf.WorkingDirectory, string(cacheFilePath[1:])) } @@ -365,9 +365,9 @@ func RegisterCache() { fileConfig := make(map[string]string, 0) fileConfig["CachePath"] = cacheFilePath - fileConfig["DirectoryLevel"] = adapter.AppConfig.DefaultString("cache_file_dir_level", "2") - fileConfig["EmbedExpiry"] = adapter.AppConfig.DefaultString("cache_file_expiry", "120") - fileConfig["FileSuffix"] = adapter.AppConfig.DefaultString("cache_file_suffix", ".bin") + fileConfig["DirectoryLevel"] = web.AppConfig.DefaultString("cache_file_dir_level", "2") + fileConfig["EmbedExpiry"] = web.AppConfig.DefaultString("cache_file_expiry", "120") + fileConfig["FileSuffix"] = web.AppConfig.DefaultString("cache_file_suffix", ".bin") bc, err := json.Marshal(&fileConfig) if err != nil { @@ -380,13 +380,13 @@ func RegisterCache() { cache.Init(fileCache) } else if cacheProvider == "memory" { - cacheInterval := adapter.AppConfig.DefaultInt("cache_memory_interval", 60) + cacheInterval := web.AppConfig.DefaultInt("cache_memory_interval", 60) memory := beegoCache.NewMemoryCache() beegoCache.DefaultEvery = cacheInterval cache.Init(memory) } else if cacheProvider == "redis" { //设置Redis前缀 - if key := adapter.AppConfig.DefaultString("cache_redis_prefix", ""); key != "" { + if key := web.AppConfig.DefaultString("cache_redis_prefix", ""); key != "" { redis.DefaultKey = key } var redisConfig struct { @@ -395,11 +395,11 @@ func RegisterCache() { DbNum int `json:"dbNum"` } redisConfig.DbNum = 0 - redisConfig.Conn = adapter.AppConfig.DefaultString("cache_redis_host", "") - if pwd := adapter.AppConfig.DefaultString("cache_redis_password", ""); pwd != "" { + redisConfig.Conn = web.AppConfig.DefaultString("cache_redis_host", "") + if pwd := web.AppConfig.DefaultString("cache_redis_password", ""); pwd != "" { redisConfig.Password = pwd } - if dbNum := adapter.AppConfig.DefaultInt("cache_redis_db", 0); dbNum > 0 { + if dbNum := web.AppConfig.DefaultInt("cache_redis_db", 0); dbNum > 0 { redisConfig.DbNum = dbNum } @@ -421,7 +421,7 @@ func RegisterCache() { var memcacheConfig struct { Conn string `json:"conn"` } - memcacheConfig.Conn = adapter.AppConfig.DefaultString("cache_memcache_host", "") + memcacheConfig.Conn = web.AppConfig.DefaultString("cache_memcache_host", "") bc, err := json.Marshal(&memcacheConfig) if err != nil { diff --git a/commands/install.go b/commands/install.go index 8d962e68..1be2b9a7 100644 --- a/commands/install.go +++ b/commands/install.go @@ -24,7 +24,6 @@ func Install() { initialization() } else { panic(err.Error()) - os.Exit(1) } fmt.Println("Install Successfully!") os.Exit(0) @@ -99,7 +98,6 @@ func initialization() { if err != nil { panic(err.Error()) - os.Exit(1) } member, err := models.NewMember().FindByFieldFirst("account", "admin") @@ -109,12 +107,11 @@ func initialization() { member.Avatar = conf.URLForWithCdnImage("/static/images/headimgurl.jpg") member.Password = "123456" member.AuthMethod = "local" - member.Role = 0 + member.Role = conf.MemberSuperRole member.Email = "admin@iminho.me" if err := member.Add(); err != nil { panic("Member.Add => " + err.Error()) - os.Exit(0) } book := models.NewBook() @@ -137,7 +134,6 @@ func initialization() { if err := book.Insert(); err != nil { panic("初始化项目失败 -> " + err.Error()) - os.Exit(1) } } @@ -147,7 +143,6 @@ func initialization() { item.MemberId = 1 if err := item.Save(); err != nil { panic("初始化项目空间失败 -> " + err.Error()) - os.Exit(1) } } } diff --git a/commands/migrate/migrate.go b/commands/migrate/migrate.go index 5c71a76b..ed269cd7 100644 --- a/commands/migrate/migrate.go +++ b/commands/migrate/migrate.go @@ -21,8 +21,8 @@ import ( "fmt" "log" - "github.com/beego/beego/v2/adapter" "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/server/web" "github.com/mindoc-org/mindoc/models" ) @@ -114,8 +114,8 @@ func RunMigration() { //导出数据库的表结构 func ExportDatabaseTable() ([]string, error) { - dbadapter := adapter.AppConfig.String("db_adapter") - dbdatabase := adapter.AppConfig.String("db_database") + dbadapter,_ := web.AppConfig.String("db_adapter") + dbdatabase,_ := web.AppConfig.String("db_database") tables := make([]string, 0) o := orm.NewOrm() diff --git a/conf/enumerate.go b/conf/enumerate.go index b863cf04..6b05e20c 100644 --- a/conf/enumerate.go +++ b/conf/enumerate.go @@ -10,6 +10,7 @@ import ( "strconv" "github.com/beego/beego/v2/adapter" + "github.com/beego/beego/v2/server/web" ) // 登录用户的Session名 @@ -81,32 +82,32 @@ var ( // app_key func GetAppKey() string { - return adapter.AppConfig.DefaultString("app_key", "mindoc") + return web.AppConfig.DefaultString("app_key", "mindoc") } func GetDatabasePrefix() string { - return adapter.AppConfig.DefaultString("db_prefix", "md_") + return web.AppConfig.DefaultString("db_prefix", "md_") } //获取默认头像 func GetDefaultAvatar() string { - return URLForWithCdnImage(adapter.AppConfig.DefaultString("avatar", "/static/images/headimgurl.jpg")) + return URLForWithCdnImage(web.AppConfig.DefaultString("avatar", "/static/images/headimgurl.jpg")) } //获取阅读令牌长度. func GetTokenSize() int { - return adapter.AppConfig.DefaultInt("token_size", 12) + return web.AppConfig.DefaultInt("token_size", 12) } //获取默认文档封面. func GetDefaultCover() string { - return URLForWithCdnImage(adapter.AppConfig.DefaultString("cover", "/static/images/book.jpg")) + return URLForWithCdnImage(web.AppConfig.DefaultString("cover", "/static/images/book.jpg")) } //获取允许的商城文件的类型. func GetUploadFileExt() []string { - ext := adapter.AppConfig.DefaultString("upload_file_ext", "png|jpg|jpeg|gif|txt|doc|docx|pdf") + ext := web.AppConfig.DefaultString("upload_file_ext", "png|jpg|jpeg|gif|txt|doc|docx|pdf") temp := strings.Split(ext, "|") @@ -124,7 +125,7 @@ func GetUploadFileExt() []string { // 获取上传文件允许的最大值 func GetUploadFileSize() int64 { - size := adapter.AppConfig.DefaultString("upload_file_size", "0") + size := web.AppConfig.DefaultString("upload_file_size", "0") if strings.HasSuffix(size, "MB") { if s, e := strconv.ParseInt(size[0:len(size)-2], 10, 64); e == nil { @@ -149,12 +150,12 @@ func GetUploadFileSize() int64 { //是否启用导出 func GetEnableExport() bool { - return adapter.AppConfig.DefaultBool("enable_export", true) + return web.AppConfig.DefaultBool("enable_export", true) } //同一项目导出线程的并发数 func GetExportProcessNum() int { - exportProcessNum := adapter.AppConfig.DefaultInt("export_process_num", 1) + exportProcessNum := web.AppConfig.DefaultInt("export_process_num", 1) if exportProcessNum <= 0 || exportProcessNum > 4 { exportProcessNum = 1 @@ -164,7 +165,7 @@ func GetExportProcessNum() int { //导出项目队列的并发数量 func GetExportLimitNum() int { - exportLimitNum := adapter.AppConfig.DefaultInt("export_limit_num", 1) + exportLimitNum := web.AppConfig.DefaultInt("export_limit_num", 1) if exportLimitNum < 0 { exportLimitNum = 1 @@ -174,7 +175,7 @@ func GetExportLimitNum() int { //等待导出队列的长度 func GetExportQueueLimitNum() int { - exportQueueLimitNum := adapter.AppConfig.DefaultInt("export_queue_limit_num", 10) + exportQueueLimitNum := web.AppConfig.DefaultInt("export_queue_limit_num", 10) if exportQueueLimitNum <= 0 { exportQueueLimitNum = 100 @@ -184,7 +185,7 @@ func GetExportQueueLimitNum() int { //默认导出项目的缓存目录 func GetExportOutputPath() string { - exportOutputPath := filepath.Join(adapter.AppConfig.DefaultString("export_output_path", filepath.Join(WorkingDirectory, "cache")), "books") + exportOutputPath := filepath.Join(web.AppConfig.DefaultString("export_output_path", filepath.Join(WorkingDirectory, "cache")), "books") return exportOutputPath } @@ -210,7 +211,7 @@ func IsAllowUploadFileExt(ext string) bool { //重写生成URL的方法,加上完整的域名 func URLFor(endpoint string, values ...interface{}) string { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "") + baseUrl := web.AppConfig.DefaultString("baseurl", "") pathUrl := adapter.URLFor(endpoint, values...) if baseUrl == "" { @@ -229,7 +230,7 @@ func URLFor(endpoint string, values ...interface{}) string { } func URLForNotHost(endpoint string, values ...interface{}) string { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "") + baseUrl := web.AppConfig.DefaultString("baseurl", "") pathUrl := adapter.URLFor(endpoint, values...) if baseUrl == "" { @@ -251,10 +252,10 @@ func URLForWithCdnImage(p string) string { if strings.HasPrefix(p, "http://") || strings.HasPrefix(p, "https://") { return p } - cdn := adapter.AppConfig.DefaultString("cdnimg", "") + cdn := web.AppConfig.DefaultString("cdnimg", "") //如果没有设置cdn,则使用baseURL拼接 if cdn == "" { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "/") + baseUrl := web.AppConfig.DefaultString("baseurl", "/") if strings.HasPrefix(p, "/") && strings.HasSuffix(baseUrl, "/") { return baseUrl + p[1:] @@ -274,7 +275,7 @@ func URLForWithCdnImage(p string) string { } func URLForWithCdnCss(p string, v ...string) string { - cdn := adapter.AppConfig.DefaultString("cdncss", "") + cdn := web.AppConfig.DefaultString("cdncss", "") if strings.HasPrefix(p, "http://") || strings.HasPrefix(p, "https://") { return p } @@ -285,7 +286,7 @@ func URLForWithCdnCss(p string, v ...string) string { } //如果没有设置cdn,则使用baseURL拼接 if cdn == "" { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "/") + baseUrl := web.AppConfig.DefaultString("baseurl", "/") if strings.HasPrefix(p, "/") && strings.HasSuffix(baseUrl, "/") { return baseUrl + p[1:] @@ -305,7 +306,7 @@ func URLForWithCdnCss(p string, v ...string) string { } func URLForWithCdnJs(p string, v ...string) string { - cdn := adapter.AppConfig.DefaultString("cdnjs", "") + cdn := web.AppConfig.DefaultString("cdnjs", "") if strings.HasPrefix(p, "http://") || strings.HasPrefix(p, "https://") { return p } @@ -318,7 +319,7 @@ func URLForWithCdnJs(p string, v ...string) string { //如果没有设置cdn,则使用baseURL拼接 if cdn == "" { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "/") + baseUrl := web.AppConfig.DefaultString("baseurl", "/") if strings.HasPrefix(p, "/") && strings.HasSuffix(baseUrl, "/") { return baseUrl + p[1:] diff --git a/conf/mail.go b/conf/mail.go index 4cb8183c..120f0c72 100644 --- a/conf/mail.go +++ b/conf/mail.go @@ -3,7 +3,7 @@ package conf import ( "strings" - "github.com/beego/beego/v2/adapter" + "github.com/beego/beego/v2/server/web" ) type SmtpConf struct { @@ -19,14 +19,14 @@ type SmtpConf struct { } func GetMailConfig() *SmtpConf { - user_name := adapter.AppConfig.String("smtp_user_name") - password := adapter.AppConfig.String("smtp_password") - smtp_host := adapter.AppConfig.String("smtp_host") - smtp_port := adapter.AppConfig.DefaultInt("smtp_port", 25) - form_user_name := adapter.AppConfig.String("form_user_name") - enable_mail := adapter.AppConfig.String("enable_mail") - mail_number := adapter.AppConfig.DefaultInt("mail_number", 5) - secure := adapter.AppConfig.DefaultString("secure", "NONE") + user_name, _ := web.AppConfig.String("smtp_user_name") + password, _ := web.AppConfig.String("smtp_password") + smtp_host, _ := web.AppConfig.String("smtp_host") + smtp_port := web.AppConfig.DefaultInt("smtp_port", 25) + form_user_name, _ := web.AppConfig.String("form_user_name") + enable_mail, _ := web.AppConfig.String("enable_mail") + mail_number := web.AppConfig.DefaultInt("mail_number", 5) + secure := web.AppConfig.DefaultString("secure", "NONE") if secure != "NONE" && secure != "LOGIN" && secure != "SSL" { secure = "NONE" diff --git a/controllers/BaseController.go b/controllers/BaseController.go index 18ebf972..4cd778ed 100644 --- a/controllers/BaseController.go +++ b/controllers/BaseController.go @@ -75,7 +75,7 @@ func (c *BaseController) Prepare() { c.EnableAnonymous = strings.EqualFold(c.Option["ENABLE_ANONYMOUS"], "true") c.EnableDocumentHistory = strings.EqualFold(c.Option["ENABLE_DOCUMENT_HISTORY"], "true") } - c.Data["HighlightStyle"] = adapter.AppConfig.DefaultString("highlight_style", "github") + c.Data["HighlightStyle"] = web.AppConfig.DefaultString("highlight_style", "github") if b, err := ioutil.ReadFile(filepath.Join(web.BConfig.WebConfig.ViewsPath, "widgets", "scripts.tpl")); err == nil { c.Data["Scripts"] = template.HTML(string(b)) @@ -167,7 +167,7 @@ func (c *BaseController) ExecuteViewPathTemplate(tplName string, data interface{ } func (c *BaseController) BaseUrl() string { - baseUrl := adapter.AppConfig.DefaultString("baseurl", "") + baseUrl := web.AppConfig.DefaultString("baseurl", "") if baseUrl != "" { if strings.HasSuffix(baseUrl, "/") { baseUrl = strings.TrimSuffix(baseUrl, "/") diff --git a/controllers/DocumentController.go b/controllers/DocumentController.go index 4c176741..b70aaa28 100644 --- a/controllers/DocumentController.go +++ b/controllers/DocumentController.go @@ -15,9 +15,9 @@ import ( "strings" "time" - "github.com/beego/beego/v2/adapter" "github.com/beego/beego/v2/adapter/orm" "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" "github.com/boombuler/barcode" "github.com/boombuler/barcode/qr" "github.com/mindoc-org/mindoc/conf" @@ -243,7 +243,7 @@ func (c *DocumentController) Edit() { } } - c.Data["BaiDuMapKey"] = adapter.AppConfig.DefaultString("baidumapkey", "") + c.Data["BaiDuMapKey"] = web.AppConfig.DefaultString("baidumapkey", "") if conf.GetUploadFileSize() > 0 { c.Data["UploadFileSize"] = conf.GetUploadFileSize() diff --git a/models/Blog.go b/models/Blog.go index 245c093d..5f0da5ac 100644 --- a/models/Blog.go +++ b/models/Blog.go @@ -7,9 +7,9 @@ import ( "time" "github.com/PuerkitoBio/goquery" - "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/adapter" "github.com/beego/beego/v2/adapter/orm" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" "github.com/mindoc-org/mindoc/cache" "github.com/mindoc-org/mindoc/conf" "github.com/mindoc-org/mindoc/utils" @@ -273,7 +273,7 @@ func (b *Blog) Processor() *Blog { } }) //设置图片为CDN地址 - if cdnimg := adapter.AppConfig.String("cdnimg"); cdnimg != "" { + if cdnimg,_ := web.AppConfig.String("cdnimg"); cdnimg != "" { content.Find("img").Each(func(i int, contentSelection *goquery.Selection) { if src, ok := contentSelection.Attr("src"); ok && strings.HasPrefix(src, "/uploads/") { contentSelection.SetAttr("src", utils.JoinURI(cdnimg, src)) diff --git a/models/DocumentModel.go b/models/DocumentModel.go index c081cbd2..4de63889 100644 --- a/models/DocumentModel.go +++ b/models/DocumentModel.go @@ -12,9 +12,9 @@ import ( "strings" "github.com/PuerkitoBio/goquery" - "github.com/beego/beego/v2/adapter" "github.com/beego/beego/v2/client/orm" "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" "github.com/mindoc-org/mindoc/cache" "github.com/mindoc-org/mindoc/conf" "github.com/mindoc-org/mindoc/utils" @@ -330,7 +330,7 @@ func (item *Document) Processor() *Document { selector.First().AppendHtml(release) } } - cdnimg := adapter.AppConfig.String("cdnimg") + cdnimg,_ := web.AppConfig.String("cdnimg") docQuery.Find("img").Each(func(i int, selection *goquery.Selection) { diff --git a/models/Member.go b/models/Member.go index 0fc3c2d0..9642823b 100644 --- a/models/Member.go +++ b/models/Member.go @@ -19,9 +19,9 @@ import ( "math" - "github.com/beego/beego/v2/adapter" "github.com/beego/beego/v2/adapter/orm" "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" "github.com/mindoc-org/mindoc/conf" "github.com/mindoc-org/mindoc/utils" ) @@ -74,10 +74,10 @@ func (m *Member) Login(account string, password string) (*Member, error) { err := o.Raw("select * from md_members where (account = ? or email = ?) and status = 0 limit 1;", account, account).QueryRow(member) if err != nil { - if adapter.AppConfig.DefaultBool("ldap_enable", false) == true { + if web.AppConfig.DefaultBool("ldap_enable", false) { logs.Info("转入LDAP登陆 ->", account) return member.ldapLogin(account, password) - } else if adapter.AppConfig.String("http_login_url") != "" { + } else if url, err := web.AppConfig.String("http_login_url"); url != "" { logs.Info("转入 HTTP 接口登陆 ->", account) return member.httpLogin(account, password) } else { @@ -107,26 +107,32 @@ func (m *Member) Login(account string, password string) (*Member, error) { //ldapLogin 通过LDAP登陆 func (m *Member) ldapLogin(account string, password string) (*Member, error) { - if adapter.AppConfig.DefaultBool("ldap_enable", false) == false { + if web.AppConfig.DefaultBool("ldap_enable", false) { return m, ErrMemberAuthMethodInvalid } var err error - lc, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", adapter.AppConfig.String("ldap_host"), adapter.AppConfig.DefaultInt("ldap_port", 3268))) + ldaphost, _ := web.AppConfig.String("ldap_host") + lc, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldaphost, web.AppConfig.DefaultInt("ldap_port", 3268))) if err != nil { logs.Error("绑定 LDAP 用户失败 ->", err) return m, ErrLDAPConnect } defer lc.Close() - err = lc.Bind(adapter.AppConfig.String("ldap_user"), adapter.AppConfig.String("ldap_password")) + ldapuser, _ := web.AppConfig.String("ldap_user") + ldappass, _ := web.AppConfig.String("ldap_password") + err = lc.Bind(ldapuser, ldappass) if err != nil { logs.Error("绑定 LDAP 用户失败 ->", err) return m, ErrLDAPFirstBind } + ldapbase, _ := web.AppConfig.String("ldap_base") + ldapfilter, _ := web.AppConfig.String("ldap_filter") + ldapattr, _ := web.AppConfig.String("ldap_attribute") searchRequest := ldap.NewSearchRequest( - adapter.AppConfig.String("ldap_base"), + ldapbase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, //修改objectClass通过配置文件获取值 - fmt.Sprintf("(&(%s)(%s=%s))", adapter.AppConfig.String("ldap_filter"), adapter.AppConfig.String("ldap_attribute"), account), + fmt.Sprintf("(&(%s)(%s=%s))", ldapfilter, ldapattr, account), []string{"dn", "mail"}, nil, ) @@ -149,7 +155,7 @@ func (m *Member) ldapLogin(account string, password string) (*Member, error) { m.Email = searchResult.Entries[0].GetAttributeValue("mail") m.AuthMethod = "ldap" m.Avatar = "/static/images/headimgurl.jpg" - m.Role = conf.SystemRole(adapter.AppConfig.DefaultInt("ldap_user_role", 2)) + m.Role = conf.SystemRole(web.AppConfig.DefaultInt("ldap_user_role", 2)) m.CreateTime = time.Now() err = m.Add() @@ -163,7 +169,7 @@ func (m *Member) ldapLogin(account string, password string) (*Member, error) { } func (m *Member) httpLogin(account, password string) (*Member, error) { - urlStr := adapter.AppConfig.String("http_login_url") + urlStr,_ := web.AppConfig.String("http_login_url") if urlStr == "" { return nil, ErrMemberAuthMethodInvalid } @@ -174,7 +180,7 @@ func (m *Member) httpLogin(account, password string) (*Member, error) { "time": []string{strconv.FormatInt(time.Now().Unix(), 10)}, } h := md5.New() - h.Write([]byte(val.Encode() + adapter.AppConfig.DefaultString("http_login_secret", ""))) + h.Write([]byte(val.Encode() + web.AppConfig.DefaultString("http_login_secret", ""))) val.Add("sn", hex.EncodeToString(h.Sum(nil))) @@ -226,7 +232,7 @@ func (m *Member) httpLogin(account, password string) (*Member, error) { member.Account = account member.Password = password member.AuthMethod = "http" - member.Role = conf.SystemRole(adapter.AppConfig.DefaultInt("ldap_user_role", 2)) + member.Role = conf.SystemRole(web.AppConfig.DefaultInt("ldap_user_role", 2)) member.CreateTime = time.Now() if err := member.Add(); err != nil { logs.Error("自动注册用户错误", err) diff --git a/routers/filter.go b/routers/filter.go index 67cc5f2f..98dacb1d 100644 --- a/routers/filter.go +++ b/routers/filter.go @@ -6,8 +6,8 @@ import ( "regexp" "github.com/beego/beego/v2/adapter" - "github.com/beego/beego/v2/server/web" "github.com/beego/beego/v2/adapter/context" + "github.com/beego/beego/v2/server/web" "github.com/mindoc-org/mindoc/conf" "github.com/mindoc-org/mindoc/models" ) @@ -47,7 +47,8 @@ func init() { } var StartRouter = func(ctx *context.Context) { - sessionId := ctx.Input.Cookie(adapter.AppConfig.String("sessionname")) + sessname, _ := web.AppConfig.String("sessionname") + sessionId := ctx.Input.Cookie(sessname) if sessionId != "" { //sessionId必须是数字字母组成,且最小32个字符,最大1024字符 if ok, err := regexp.MatchString(`^[a-zA-Z0-9]{32,512}$`, sessionId); !ok || err != nil {