//  Copyright (c) 2020 Cisco and/or its affiliates.
//
//  Licensed under the Apache License, Version 2.0 (the "License");
//  you may not use this file except in compliance with the License.
//  You may obtain a copy of the License at:
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.

package binapigen

import (
	"bufio"
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	"github.com/sirupsen/logrus"

	"git.fd.io/govpp.git/binapigen/vppapi"
)

type Generator struct {
	Files       []*File
	FilesByName map[string]*File
	FilesByPath map[string]*File

	opts       Options
	apifiles   []*vppapi.File
	vppVersion string

	filesToGen []string
	genfiles   []*GenFile

	enumsByName    map[string]*Enum
	aliasesByName  map[string]*Alias
	structsByName  map[string]*Struct
	unionsByName   map[string]*Union
	messagesByName map[string]*Message
}

func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
	gen := &Generator{
		FilesByName:    make(map[string]*File),
		FilesByPath:    make(map[string]*File),
		opts:           opts,
		apifiles:       apiFiles,
		filesToGen:     filesToGen,
		enumsByName:    map[string]*Enum{},
		aliasesByName:  map[string]*Alias{},
		structsByName:  map[string]*Struct{},
		unionsByName:   map[string]*Union{},
		messagesByName: map[string]*Message{},
	}

	// Normalize API files
	SortFilesByImports(gen.apifiles)
	for _, apiFile := range apiFiles {
		RemoveImportedTypes(gen.apifiles, apiFile)
		SortFileObjectsByName(apiFile)
	}

	// prepare package names and import paths
	packageNames := make(map[string]GoPackageName)
	importPaths := make(map[string]GoImportPath)
	for _, apifile := range gen.apifiles {
		filename := getFilename(apifile)
		packageNames[filename] = cleanPackageName(apifile.Name)
		importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
	}

	logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))

	for _, apifile := range gen.apifiles {
		if _, ok := gen.FilesByName[apifile.Name]; ok {
			return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
		}

		filename := getFilename(apifile)
		file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
		if err != nil {
			return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
		}
		gen.Files = append(gen.Files, file)
		gen.FilesByName[apifile.Name] = file
		gen.FilesByPath[apifile.Path] = file

		logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
	}

	// mark files for generation
	if len(gen.filesToGen) > 0 {
		logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
		for _, genFile := range gen.filesToGen {
			markGen := func(file *File) {
				file.Generate = true
				// generate all imported files
				for _, impFile := range file.importedFiles(gen) {
					impFile.Generate = true
				}
			}
			if file, ok := gen.FilesByName[genFile]; ok {
				markGen(file)
				continue
			}
			logrus.Debugf("File %s was not found by name", genFile)
			if file, ok := gen.FilesByPath[genFile]; ok {
				markGen(file)
				continue
			}
			return nil, fmt.Errorf("no API file found for: %v", genFile)
		}
	} else {
		logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
		for _, file := range gen.Files {
			file.Generate = true
		}
	}

	return gen, nil
}

func getFilename(file *vppapi.File) string {
	if file.Path == "" {
		return file.Name
	}
	return file.Path
}

func (g *Generator) Generate() error {
	if len(g.genfiles) == 0 {
		return fmt.Errorf("no files to generate")
	}

	logrus.Infof("Generating %d files", len(g.genfiles))

	for _, genfile := range g.genfiles {
		content, err := genfile.Content()
		if err != nil {
			return err
		}
		if err := writeSourceTo(genfile.filename, content); err != nil {
			return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
		}
	}
	return nil
}

type GenFile struct {
	gen           *Generator
	file          *File
	filename      string
	goImportPath  GoImportPath
	buf           bytes.Buffer
	manualImports map[GoImportPath]bool
	packageNames  map[GoImportPath]GoPackageName
}

// NewGenFile creates new generated file with
func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
	f := &GenFile{
		gen:           g,
		filename:      filename,
		goImportPath:  importPath,
		manualImports: make(map[GoImportPath]bool),
		packageNames:  make(map[GoImportPath]GoPackageName),
	}
	g.genfiles = append(g.genfiles, f)
	return f
}

func (g *GenFile) Write(p []byte) (n int, err error) {
	return g.buf.Write(p)
}

func (g *GenFile) Import(importPath GoImportPath) {
	g.manualImports[importPath] = true
}

func (g *GenFile) GoIdent(ident GoIdent) string {
	if ident.GoImportPath == g.goImportPath {
		return ident.GoName
	}
	if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
		return string(packageName) + "." + ident.GoName
	}
	packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
	g.packageNames[ident.GoImportPath] = packageName
	return string(packageName) + "." + ident.GoName
}

func (g *GenFile) P(v ...interface{}) {
	for _, x := range v {
		switch x := x.(type) {
		case GoIdent:
			fmt.Fprint(&g.buf, g.GoIdent(x))
		default:
			fmt.Fprint(&g.buf, x)
		}
	}
	fmt.Fprintln(&g.buf)
}

func (g *GenFile) Content() ([]byte, error) {
	if !strings.HasSuffix(g.filename, ".go") {
		return g.buf.Bytes(), nil
	}
	return g.injectImports(g.buf.Bytes())
}

func getImportClass(importPath string) int {
	if !strings.Contains(importPath, ".") {
		return 0 /* std */
	}
	return 1 /* External */
}

// injectImports parses source, injects import block declaration with all imports and return formatted
func (g *GenFile) injectImports(original []byte) ([]byte, error) {
	// Parse source code
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
	if err != nil {
		var src bytes.Buffer
		s := bufio.NewScanner(bytes.NewReader(original))
		for line := 1; s.Scan(); line++ {
			fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
		}
		return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
	}
	type Import struct {
		Name string
		Path string
	}
	// Prepare list of all imports
	var importPaths []Import
	for importPath := range g.packageNames {
		importPaths = append(importPaths, Import{
			Name: string(g.packageNames[importPath]),
			Path: string(importPath),
		})
	}
	for importPath := range g.manualImports {
		if _, ok := g.packageNames[importPath]; ok {
			continue
		}
		importPaths = append(importPaths, Import{
			Name: "_",
			Path: string(importPath),
		})
	}
	// Sort imports by import path
	sort.Slice(importPaths, func(i, j int) bool {
		ci := getImportClass(importPaths[i].Path)
		cj := getImportClass(importPaths[j].Path)
		if ci == cj {
			return importPaths[i].Path < importPaths[j].Path
		}
		return ci < cj
	})
	// Inject new import block into parsed AST
	if len(importPaths) > 0 {
		// Find import block position
		pos := file.Package
		tokFile := fset.File(file.Package)
		pkgLine := tokFile.Line(file.Package)
		for _, c := range file.Comments {
			if tokFile.Line(c.Pos()) > pkgLine {
				break
			}
			pos = c.End()
		}
		// Prepare the import block
		impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
		for i, importPath := range importPaths {
			var name *ast.Ident
			if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
				name = &ast.Ident{Name: importPath.Name, NamePos: pos}
			}
			value := strconv.Quote(importPath.Path)
			if i < len(importPaths)-1 {
				if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
					value += "\n"
				}
			}
			impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
				Name:   name,
				Path:   &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
				EndPos: pos,
			})
		}

		file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
	}
	// Reformat source code
	var out bytes.Buffer
	cfg := &printer.Config{
		Mode:     printer.TabIndent | printer.UseSpaces,
		Tabwidth: 8,
	}
	if err = cfg.Fprint(&out, fset, file); err != nil {
		return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
	}
	return out.Bytes(), nil
}

func writeSourceTo(outputFile string, b []byte) error {
	// create output directory
	packageDir := filepath.Dir(outputFile)
	if err := os.MkdirAll(packageDir, 0775); err != nil {
		return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
	}

	// write generated code to output file
	if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
		return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
	}

	lines := bytes.Count(b, []byte("\n"))
	logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)

	return nil
}