diff --git a/commands/commands.go b/commands/commands.go index 5d61e2e..93dfc4f 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -17,6 +17,7 @@ import ( "io/ioutil" "log" "net/http" + "net/url" "github.com/jollheef/wi/storage" @@ -25,15 +26,15 @@ import ( "golang.org/x/net/html/charset" ) -func parseLink(db *sql.DB, oldPage, value string, req *http.Request) (htmlPage string, err error) { - url, err := req.URL.Parse(value) +func parseLink(db *sql.DB, oldPage, value string, lastUrl *url.URL) (htmlPage string, err error) { + linkUrl, err := lastUrl.Parse(value) if err != nil { return } - linkNo, err := storage.GetLinkID(db, url.String()) + linkNo, err := storage.GetLinkID(db, linkUrl.String()) if err != nil { - linkNo, err = storage.AddLink(db, url.String()) + linkNo, err = storage.AddLink(db, linkUrl.String()) if err != nil { return } @@ -49,7 +50,7 @@ func parseLink(db *sql.DB, oldPage, value string, req *http.Request) (htmlPage s return } -func parseLinks(db *sql.DB, body []byte, req *http.Request) (htmlPage string, err error) { +func parseLinks(db *sql.DB, body []byte, lastUrl *url.URL) (htmlPage string, err error) { htmlPage = string(body) z := html.NewTokenizer(bytes.NewReader(body)) @@ -64,7 +65,7 @@ func parseLinks(db *sql.DB, body []byte, req *http.Request) (htmlPage string, er key, value, moreAttr := z.TagAttr() if string(key) == "href" { - htmlPage, err = parseLink(db, htmlPage, string(value), req) + htmlPage, err = parseLink(db, htmlPage, string(value), lastUrl) if err != nil { return } @@ -79,15 +80,22 @@ func parseLinks(db *sql.DB, body []byte, req *http.Request) (htmlPage string, er return } -func Get(db *sql.DB, url string) { +func Get(db *sql.DB, linkUrl string) { client := &http.Client{} - if !strings.Contains(url, "://") { - url = "https://" + url + var lastUrl *url.URL + + client.CheckRedirect = func(r *http.Request, via []*http.Request) (err error) { + lastUrl = r.URL + return + } + + if !strings.Contains(linkUrl, "://") { + linkUrl = "https://" + linkUrl } // TODO Full url encoding - req, err := http.NewRequest("GET", strings.Replace(url, " ", "%20", -1), nil) + req, err := http.NewRequest("GET", strings.Replace(linkUrl, " ", "%20", -1), nil) if err != nil { log.Fatalln(err) } @@ -99,7 +107,11 @@ func Get(db *sql.DB, url string) { log.Fatalln(err) } - storage.AddHistoryURL(db, url) + if lastUrl == nil { + lastUrl = req.URL + } + + storage.AddHistoryURL(db, linkUrl) defer resp.Body.Close() @@ -113,7 +125,7 @@ func Get(db *sql.DB, url string) { log.Fatalln("IO error:", err) } - htmlPage, err := parseLinks(db, body, req) + htmlPage, err := parseLinks(db, body, lastUrl) if err != nil { log.Fatalln("Parse links error:", err) } @@ -129,20 +141,20 @@ func Get(db *sql.DB, url string) { func Link(db *sql.DB, linkID int64, fromHistory bool) { - var url string + var linkUrl string var err error if fromHistory { - url, err = storage.GetHistoryUrl(db, linkID) + linkUrl, err = storage.GetHistoryUrl(db, linkID) } else { - url, err = storage.GetLink(db, linkID) + linkUrl, err = storage.GetLink(db, linkID) } if err != nil { log.Fatalln("Get link/history url error:", err) } - Get(db, url) + Get(db, linkUrl) } func History(db *sql.DB, argAmount, defaultAmount int64, all bool) {