Skip to content

Commit

Permalink
fix archive root determination logic (#2)
Browse files Browse the repository at this point in the history
Also add integration tests
  • Loading branch information
femnad authored Dec 9, 2023
1 parent 43eb033 commit 43cabf8
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 53 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
on:
pull_request:
types:
- opened
push:
branches:
- main
workflow_dispatch:
inputs:

jobs:
test:
Expand All @@ -10,4 +15,10 @@ jobs:
steps:
- uses: actions/checkout@v4

- run: go test ./...
- name: Run Go tests
run: go test ./...

- name: Run integration tests
run: |
go install
python3 tests/test_archives.py
11 changes: 11 additions & 0 deletions base/settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ import (
"github.com/femnad/fup/internal"
)

const defaultBinPath = "~/bin"

type FactMap map[string]map[string]string

type Settings struct {
BinDir string `yaml:"bin_dir"`
CloneDir string `yaml:"clone_dir"`
EnsureEnv map[string]string `yaml:"ensure_env"`
EnsurePaths []string `yaml:"ensure_paths"`
Expand All @@ -24,6 +27,14 @@ type Settings struct {
VirtualEnvDir string `yaml:"virtualenv_dir"`
}

// GetBinPath returns the bin directory configured via the bin_dir setting,
// falling back to defaultBinPath when the setting is empty. The value is
// returned verbatim; no expansion is performed here.
func (s Settings) GetBinPath() string {
	binPath := defaultBinPath
	if s.BinDir != "" {
		binPath = s.BinDir
	}

	return binPath
}

func expand(s string, lookup map[string]string) string {
var cur bytes.Buffer
var out bytes.Buffer
Expand Down
139 changes: 91 additions & 48 deletions provision/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,37 @@ import (
"regexp"
"strings"

mapset "github.com/deckarep/golang-set/v2"
"github.com/xi2/xz"

marecmd "github.com/femnad/mare/cmd"

"github.com/femnad/fup/base"
"github.com/femnad/fup/base/settings"
"github.com/femnad/fup/internal"
"github.com/femnad/fup/precheck/unless"
"github.com/femnad/fup/precheck/when"
"github.com/femnad/fup/remote"
"github.com/femnad/mare"
marecmd "github.com/femnad/mare/cmd"
)

const (
bufferSize = 8192
dirMode = 0755
xzDictMax = 1 << 27
tarFileRegex = `\.tar(\.(gz|bz2|xz))$`
tarFileRegex = `\.tar(\.(gz|bz2|xz))?$`
)

type archiveInfo struct {
hasRootDir bool
target string
}

// archiveEntry describes a single member of an archive: its file metadata and
// the name it is stored under. Entries are collected while listing an archive
// and are then used to work out the archive's root.
type archiveEntry struct {
// info is the entry's file metadata; it is checked to find executable files.
info os.FileInfo
// name is the entry's path within the archive; its first "/"-separated
// component is treated as the entry's root.
name string
}

// archiveRoot stores an archive's root dir and specifies if the root dir is part of the archive files.
type archiveRoot struct {
// hasRootDir reports whether the entry names already start with the root dir.
// When it is false, target is inserted between the extraction dir and each
// entry name when building output paths.
hasRootDir bool
// target is the name of the root dir: the single top-level component shared
// by all entries, or otherwise the archive's name or the first executable's name.
target string
}

func processDownload(archive base.Archive, s settings.Settings) (string, error) {
url := archive.ExpandURL(s)
if url == "" {
Expand All @@ -66,7 +68,7 @@ func processDownload(archive base.Archive, s settings.Settings) (string, error)
return "", err
}

return extractFn(response, dirName)
return extractFn(archive, response, dirName)
}

func getReader(response remote.Response, tempFile *os.File) (io.Reader, error) {
Expand Down Expand Up @@ -191,8 +193,8 @@ func downloadTempFile(response remote.Response) (string, error) {
return tempFilePath, nil
}

func isExec(info os.FileInfo) bool {
return info.Mode().Perm()&0100 != 0
// isExecutableFile reports whether info describes a non-directory entry whose
// owner-execute permission bit is set.
func isExecutableFile(info os.FileInfo) bool {
	if info.IsDir() {
		return false
	}

	const ownerExec = 0100
	return info.Mode().Perm()&ownerExec != 0
}

func getFilename(response remote.Response) string {
Expand All @@ -203,43 +205,75 @@ func getFilename(response remote.Response) string {
return filename
}

func getArchiveInfo(target string) archiveInfo {
if strings.Index(target, "/") >= 0 {
firstSlash := strings.Index(target, "/")
target = target[:firstSlash]
func commonPrefix(names []string) string {
if len(names) == 0 {
return ""
}

first := names[0]
minLength := len(first)
for _, name := range names {
if len(name) < minLength {
minLength = len(name)
}
}

for i := 0; i < minLength; i++ {
currentChar := first[i]
for _, name := range names {
if name[i] != currentChar {
return first[:i]
}
}
}
return archiveInfo{target: target, hasRootDir: true}

return first[:minLength]
}

func inspectEntries(entries []archiveEntry, url string) (archiveInfo, error) {
var execs []string
func determineArchiveRoot(archive base.Archive, entries []archiveEntry) (archiveRoot, error) {
names := mare.Map(entries, func(entry archiveEntry) string {
return entry.name
})
prefix := commonPrefix(names)
roots := mapset.NewSet[string]()
var execs []archiveEntry
for _, entry := range entries {
info := entry.info
name := entry.name
if info.IsDir() {
return getArchiveInfo(entry.name), nil
} else if isExec(info) {
execs = append(execs, name)
rootDir := strings.Split(entry.name, "/")
roots.Add(rootDir[0])
if isExecutableFile(entry.info) {
execs = append(execs, entry)
}
}

if len(execs) == 1 {
return getArchiveInfo(execs[0]), nil
var hasRootDir bool
var target string
if roots.Cardinality() == 1 {
root, ok := roots.Pop()
if !ok {
return archiveRoot{}, fmt.Errorf("error determining root dir for %s", archive.Url)
}

hasRootDir = strings.Index(prefix, "/") > -1
target = root
} else if archive.Name() != "" {
target = archive.Name()
} else {
target = execs[0].name
}

return archiveInfo{}, fmt.Errorf("unable to determine root for archive: %s", url)
return archiveRoot{hasRootDir: hasRootDir, target: target}, nil
}

func getTarInfo(tempfile string, response remote.Response) (archiveInfo, error) {
func getTarInfo(archive base.Archive, response remote.Response, tempfile string) (archiveRoot, error) {
f, err := os.Open(tempfile)
if err != nil {
return archiveInfo{}, err
return archiveRoot{}, err
}
defer f.Close()

reader, err := getReader(response, f)
if err != nil {
return archiveInfo{}, err
return archiveRoot{}, err
}

var entries []archiveEntry
Expand All @@ -258,25 +292,34 @@ func getTarInfo(tempfile string, response remote.Response) (archiveInfo, error)
})
}

return inspectEntries(entries, response.URL)
return determineArchiveRoot(archive, entries)
}

func getOutputPath(info archiveInfo, fileName, dirName string) string {
if info.hasRootDir {
func getOutputPath(archiveRoot archiveRoot, fileName, dirName string) string {
if archiveRoot.hasRootDir {
return filepath.Join(dirName, fileName)
}

return filepath.Join(dirName, info.target, fileName)
return filepath.Join(dirName, archiveRoot.target, fileName)
}

// getAbsTarget returns the absolute path of the archive root inside dirName.
//
// dirName is the directory the archive was extracted into and root.target is
// the name of the archive's root dir below it. When dirName is already
// absolute the joined path is returned as is; prefixing it with the working
// directory would produce a path that does not exist. A relative dirName is
// resolved against the current working directory.
//
// An error is returned only when the working directory is needed and cannot
// be determined.
func getAbsTarget(dirName string, root archiveRoot) (string, error) {
	target := path.Join(dirName, root.target)
	if path.IsAbs(target) {
		return target, nil
	}

	wd, err := os.Getwd()
	if err != nil {
		return "", fmt.Errorf("resolving working directory for %q: %w", target, err)
	}

	return path.Join(wd, target), nil
}

// Shamelessly lifted from https://golangdocs.com/tar-gzip-in-golang
func untar(response remote.Response, dirName string) (string, error) {
func untar(archive base.Archive, response remote.Response, dirName string) (string, error) {
tempfile, err := downloadTempFile(response)
if err != nil {
return "", err
}

aInfo, err := getTarInfo(tempfile, response)
root, err := getTarInfo(archive, response, tempfile)

f, err := os.Open(tempfile)
if err != nil {
Expand All @@ -298,7 +341,7 @@ func untar(response remote.Response, dirName string) (string, error) {
panic(err)
}

outputPath := getOutputPath(aInfo, header.Name, dirName)
outputPath := getOutputPath(root, header.Name, dirName)
err = extractCompressedFile(header.FileInfo(), outputPath, tarReader)
if err != nil {
return "", err
Expand All @@ -310,13 +353,13 @@ func untar(response remote.Response, dirName string) (string, error) {
return "", err
}

return aInfo.target, nil
return getAbsTarget(dirName, root)
}

func getZipInfo(tempFile string, response remote.Response) (archiveInfo, error) {
func getZipInfo(archive base.Archive, tempFile string) (archiveRoot, error) {
zipArchive, err := zip.OpenReader(tempFile)
if err != nil {
return archiveInfo{}, err
return archiveRoot{}, err
}

var entries []archiveEntry
Expand All @@ -329,16 +372,16 @@ func getZipInfo(tempFile string, response remote.Response) (archiveInfo, error)

err = zipArchive.Close()
if err != nil {
return archiveInfo{}, err
return archiveRoot{}, err
}

return inspectEntries(entries, response.URL)
return determineArchiveRoot(archive, entries)
}

func unzip(response remote.Response, dirName string) (string, error) {
func unzip(archive base.Archive, response remote.Response, dirName string) (string, error) {
tempFile, err := downloadTempFile(response)

aInfo, err := getZipInfo(tempFile, response)
root, err := getZipInfo(archive, tempFile)
if err != nil {
return "", err
}
Expand All @@ -349,7 +392,7 @@ func unzip(response remote.Response, dirName string) (string, error) {
}

for _, f := range zipArchive.File {
output := getOutputPath(aInfo, f.Name, dirName)
output := getOutputPath(root, f.Name, dirName)
err = unzipFile(f, output)
if err != nil {
return "", err
Expand All @@ -361,10 +404,10 @@ func unzip(response remote.Response, dirName string) (string, error) {
return "", err
}

return aInfo.target, nil
return getAbsTarget(dirName, root)
}

func getExtractionFn(archive base.Archive, s settings.Settings, contentDisposition string) (func(remote.Response, string) (string, error), error) {
func getExtractionFn(archive base.Archive, s settings.Settings, contentDisposition string) (func(base.Archive, remote.Response, string) (string, error), error) {
fileName := archive.ExpandURL(s)
if contentDisposition != "" {
fileName = contentDisposition
Expand Down Expand Up @@ -443,7 +486,7 @@ func extractArchive(archive base.Archive, s settings.Settings) {
}

for _, symlink := range archive.ExpandSymlinks(s, target) {
err = createSymlink(symlink, path.Join(s.ExtractDir, target))
err = createSymlink(symlink, target, s.BinDir)
if err != nil {
internal.Log.Errorf("error creating symlink for archive %s: %v", url, err)
return
Expand Down
Loading

0 comments on commit 43cabf8

Please sign in to comment.