diff --git a/CMakeLists.txt b/CMakeLists.txt index 83b798142..2c38064a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -160,7 +160,7 @@ check_toolchain() # NB: currently, ANTLR is used in dsl examples only, # however, there is a plan to use in the frontend, # so it is kept in the top-level cmake -if(BUDDY_DSL_EXAMPLES) +if(BUDDY_DSL_EXAMPLES OR FeGen) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Antlr) # required if linking to static library diff --git a/examples/FrontendGen/.gitignore b/examples/FrontendGen/.gitignore index 34ec8116e..72d477002 100644 --- a/examples/FrontendGen/.gitignore +++ b/examples/FrontendGen/.gitignore @@ -1,3 +1,4 @@ -Toy.g4 -MLIRToyVisitor.h - +test/ +*.g4 +*.td +*.cpp \ No newline at end of file diff --git a/examples/FrontendGen/example.fegen b/examples/FrontendGen/example.fegen index cc453795c..55caa409d 100644 --- a/examples/FrontendGen/example.fegen +++ b/examples/FrontendGen/example.fegen @@ -1,141 +1,106 @@ -dialect Toy_Dialect - : name = "toy" - : cppNamespace = "mlir::toy" - ; - -op ConstantOp - : arguments = (ins F64ElementsAttr : $value) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "DenseElementsAttr" : $value), - [{ build($_builder, $_state, value.getType(), value); }]>, - OpBuilder<(ins "double":$value)>] - ; - -op AddOp - : arguments = (ins F64Tensor : $lhs, F64Tensor: $rhs) - : results = (outs F64Tensor) - : builders = [OpBuilder<(ins "Value" : $lhs, "Value" : $rhs)>] - ; - -op CastOp - : arguments = (ins F64Tensor:$input) - : results = (outs F64Tensor:$output) - ; - -op FuncOp - : arguments = (ins - SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type - ) - : builders = [ OpBuilder<(ins - "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs)> - ] - ; - -op MulOp - : arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ] - ; - -op PrintOp - : arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input) - ; - -op ReshapeOp - : arguments = (ins F64Tensor : $input) - : results = (outs StaticShapeTensorOf<[F64]>) - ; - -op ReturnOp - : arguments = (ins Variadic:$input) - : builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> - ] - ; - -op GenericCallOp - : arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> - ] - ; - -op TransposeOp - : arguments = (ins F64Tensor:$input) - : results = (outs F64Tensor) - : builders = [ - OpBuilder<(ins "Value":$input)> - ] - ; - - -rule module - : funDefine - ; - -rule expression - : Number - : tensorLiteral - : identifierExpr - : expression Add expression - ; - -rule returnExpr - : Return expression? - ; - -rule identifierExpr - : Identifier - : Identifier ParentheseOpen (expression (Comma expression) *)? ParentheseClose { - builder = GenericCallOp_1, PrintOp_0 - } - ; - -rule tensorLiteral - : SbracketOpen ( tensorLiteral ( Comma tensorLiteral ) *) ? SbracketClose - : Number - ; - -rule varDecl - : Var Identifier (type) ? (Equal expression) ? { - builder = ReshapeOp_0 - } - ; - -rule type - : AngleBracketOpen Number(Comma Number) * AngleBracketClose - ; - -rule funDefine - : prototype block { - builder = ReturnOp_1 - } - ; - -rule prototype - : Def Identifier ParentheseOpen declList ? ParentheseClose { - builder = FuncOp_0 - } - ; - -rule declList - : Identifier - : Identifier Comma declList - ; - -rule block - : BracketOpen(blockExpr Semi) * BracketClose - ; - -rule blockExpr - : varDecl - : returnExpr - : expression - ; - +fegen toy + +typedef struct { + parameters [list elementTypes] // ArrayParameter<'Type'> +} + +Type Toy_Type = any<[Tensor, struct]>; + +opdef constant { + arguments [operand list> numberAttr] // Variadic + results [operand Tensor res] + body { + list shape = shapeOf(res); + // full是一个内置函数,创建memref,并将每个元素都填充numberAttr + res = full(shape, numberAttr); + } +} + +opdef add { + arguments [operand Tensor lhs, operand Tensor rhs] + results [operand Tensor res] + body { + // 这个'+'也是一个内置的函数 + res = lhs + rhs; // res = builder.create(lsh, rhs); + } +} + +opdef mul { + arguments [operand Tensor lhs, operand Tensor rhs] + results [operand Tensor res] + body { + // 这个'*'也是一个内置的函数 + res = lhs * rhs; + } +} + +opdef reshape { + arguments [operand F64Tensor input] + results [operand F64Tensor output] + body { + list shape = shapeOf(output); + output = reshape(input, shape); + } +} + +double stod(string numStr){ + double res = 0; + int index; + int i; + for(i = 0; i <= len(numStr)-1; i=i+1){ + char c = numStr[0]; + int charNum; + if(c == '0'){ + charNum = 0; + }else if (c == '1'){ + charNum = 1; + }else if (c == '2'){ + charNum = 2; + }else if (c == '3'){ + charNum = 3; + }else if (c == '4'){ + charNum = 4; + }else if (c == '5'){ + charNum = 5; + }else if (c == '6'){ + charNum = 6; + }else if (c == '7'){ + charNum = 7; + }else if (c == '8'){ + charNum = 8; + }else if (c == '9'){ + charNum = 9; + }else if (c == '.'){ + index = i; + } + res = res * 10; + res = res + charNum; + } + res = res * 0.1**(len(numStr) - 1 - index); + return res; +} + + +module + : structDefine* funDefine+ + ; + +structDefine + : Struct Identifier BracketOpen (varDecl Semicolon)* BracketClose + ; + +// cpp value --get--> mlir::attribute || --constant Operation--> mlir::Value +// ======== || + +expression + : Number + { + returns [operand F64Tensor ret, operand F64Tensor ret] + actions { + // Type mlir::Value ret of operator | Attribute | Cpp Value + double numberAttr = stod($Number().getText()); + Type retType = Tensor<[], double>; + ret = constant(numberAttr, retType); + } + } + ; \ No newline at end of file diff --git a/examples/FrontendGen/function.fegen b/examples/FrontendGen/function.fegen new file mode 100644 index 000000000..1c3745e08 --- /dev/null +++ b/examples/FrontendGen/function.fegen @@ -0,0 +1,18 @@ +fegen toy + +double stod(string numStr){ + float res = 0.0; + int c = 1; + for(int i = 0; i < 3; i = i+1){ + if(c == 0){ + int charNum = 0; + int intNum = 1; + intNum = 1; + }else if (c == 1){ + int charNum = 1; + }else { + int charNum = 2; + } + } + return res; +} \ No newline at end of file diff --git a/examples/FrontendGen/makefile b/examples/FrontendGen/makefile index 26b604972..9fbf4d746 100644 --- a/examples/FrontendGen/makefile +++ b/examples/FrontendGen/makefile @@ -1,12 +1,17 @@ #!/bin/bash BUDDY_FRONTEND_GEN := ../../build/bin/buddy-frontendgen -frontendgen-emit-ast: - @${BUDDY_FRONTEND_GEN} -f ./example.fegen -emit=ast +opDefine: + @${BUDDY_FRONTEND_GEN} -f ./opDefine.fegen -frontendgen-emit-antlr: - @${BUDDY_FRONTEND_GEN} -f ./example.fegen -emit=antlr -g Toy +typeDefine: + @${BUDDY_FRONTEND_GEN} -f ./typeDefine.fegen -frontendgen-emit-visitor: - @${BUDDY_FRONTEND_GEN} -f ./example.fegen -emit=visitor -g Toy +rule: + @${BUDDY_FRONTEND_GEN} -f ./rule.fegen +function: + @${BUDDY_FRONTEND_GEN} -f ./function.fegen + +clean: + rm -f ./toy* \ No newline at end of file diff --git a/examples/FrontendGen/opDefine.fegen b/examples/FrontendGen/opDefine.fegen new file mode 100644 index 000000000..65d21e636 --- /dev/null +++ b/examples/FrontendGen/opDefine.fegen @@ -0,0 +1,14 @@ +fegen toy + +opdef add { + arguments [operand Integer lhs, operand Integer rhs] + results [operand Integer res] + body { + res = lhs + rhs; + } +} + +opdef constant { + arguments [attribute double value] + results [operand Tensor> res] +} \ No newline at end of file diff --git a/examples/FrontendGen/rule.fegen b/examples/FrontendGen/rule.fegen new file mode 100644 index 000000000..912d495c8 --- /dev/null +++ b/examples/FrontendGen/rule.fegen @@ -0,0 +1,14 @@ +fegen toy + +module + : structDefine* funDefine+ + ; + +structDefine + : Struct Identifier BracketOpen (varDecl Semicolon)* BracketClose + ; + +expression + : Number + | Identifier + ; \ No newline at end of file diff --git a/examples/FrontendGen/typeDefine.fegen b/examples/FrontendGen/typeDefine.fegen new file mode 100644 index 000000000..d4be6997d --- /dev/null +++ b/examples/FrontendGen/typeDefine.fegen @@ -0,0 +1,17 @@ +fegen toy + +typedef struct { + parameters [list elementTypes] +} + +typedef test1 { + parameters [Type e] +} + +typedef test2 { + parameters [list e] +} + +typedef test3 { + parameters [int e] +} \ No newline at end of file diff --git a/frontend/CMakeLists.txt b/frontend/CMakeLists.txt index 39e683c04..7a2d23547 100644 --- a/frontend/CMakeLists.txt +++ b/frontend/CMakeLists.txt @@ -1,4 +1,6 @@ -add_subdirectory(FrontendGen) +if(FeGen) + add_subdirectory(FrontendGen) +endif() add_subdirectory(Interfaces) if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) add_subdirectory(Python) diff --git a/frontend/FrontendGen/.gitignore b/frontend/FrontendGen/.gitignore new file mode 100644 index 000000000..91827d60b --- /dev/null +++ b/frontend/FrontendGen/.gitignore @@ -0,0 +1 @@ +.antlr/ \ No newline at end of file diff --git a/frontend/FrontendGen/CMakeLists.txt b/frontend/FrontendGen/CMakeLists.txt index 0f6705193..8294448b4 100644 --- a/frontend/FrontendGen/CMakeLists.txt +++ b/frontend/FrontendGen/CMakeLists.txt @@ -1,10 +1,19 @@ -include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include") -link_directories("${CMAKE_CURRENT_BINARY_DIR}/lib") add_subdirectory(lib) set (LLVM_LINK_COMPONENTS -support -frontendgenlib -) + support +) + +include_directories(${ANTLR_FegenLexer_OUTPUT_DIR}) +include_directories(${ANTLR_FegenParser_OUTPUT_DIR}) +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/include") + add_llvm_tool(buddy-frontendgen -frontendgen.cpp + frontendgen.cpp +) + +target_link_libraries(buddy-frontendgen + PRIVATE + fegen_antlr_generated + fegenVisitor + antlr4_static ) diff --git a/frontend/FrontendGen/README.md b/frontend/FrontendGen/README.md new file mode 100644 index 000000000..13b5b2e69 --- /dev/null +++ b/frontend/FrontendGen/README.md @@ -0,0 +1,14 @@ +# How to build + +FrontendGen is designed for generate mlir project quickly by writing fegen files. + +The `FeGen` option needs to be enabled when building. + +``` bash +$ cmake -G Ninja .. \ + -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \ + -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DFeGen=ON \ + -DCMAKE_BUILD_TYPE=RELEASE +``` \ No newline at end of file diff --git a/frontend/FrontendGen/frontendgen.cpp b/frontend/FrontendGen/frontendgen.cpp index bfb6ef200..8490405f4 100644 --- a/frontend/FrontendGen/frontendgen.cpp +++ b/frontend/FrontendGen/frontendgen.cpp @@ -18,111 +18,55 @@ // //===----------------------------------------------------------------------===// -#include "CGModule.h" -#include "Diagnostics.h" -#include "Lexer.h" -#include "Parser.h" +#include + #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "FegenLexer.h" +#include "FegenParser.h" +#include "FegenVisitor.h" +#include "antlr4-common.h" + llvm::cl::opt inputFileName("f", llvm::cl::desc("")); -llvm::cl::opt grammarName("g", llvm::cl::desc("")); namespace { enum Action { none, dumpAst, dumpAntlr, dumpAll, dumpVisitor }; } -llvm::cl::opt emitAction( - "emit", llvm::cl::desc("Select the kind of output desired"), - llvm::cl::values(clEnumValN(dumpAst, "ast", "Out put the ast")), - llvm::cl::values(clEnumValN(dumpAntlr, "antlr", "Out put the antlr file")), - llvm::cl::values(clEnumValN(dumpVisitor, "visitor", - "Out put the visitor file")), - llvm::cl::values(clEnumValN(dumpAll, "all", "put out all file"))); - -/// Control generation of ast, tablegen files and antlr files. -void emit(frontendgen::Module *module, frontendgen::Terminators &terminators) { - bool emitAst = emitAction == Action::dumpAst; - bool emitAntlr = - emitAction == Action::dumpAntlr || emitAction == Action::dumpAll; - bool emitVisitor = - emitAction == Action::dumpVisitor || emitAction == Action::dumpAll; - // Emit antlr file. - if (emitAntlr) { - if (grammarName.empty()) { - llvm::errs() << "if you want to emit g4 file you have to point out the " - "name of grammar.\n"; - return; - } - std::error_code EC; - llvm::sys::fs::OpenFlags openFlags = llvm::sys::fs::OpenFlags::OF_None; - std::string outputFileName = grammarName.c_str(); - outputFileName += ".g4"; - auto Out = llvm::ToolOutputFile(outputFileName, EC, openFlags); - frontendgen::CGModule CGmodule(module, Out.os(), terminators); - CGmodule.emitAntlr(grammarName); - Out.keep(); - } - // Emit antlr's AST. - if (emitAst && !module->getRules().empty()) { - llvm::raw_fd_ostream os(-1, true); - frontendgen::CGModule CGmodule(module, os, terminators); - CGmodule.emitAST(); - } - // Emit visitor file. - if (emitVisitor && !module->getRules().empty()) { - std::error_code EC; - llvm::sys::fs::OpenFlags openFlags = llvm::sys::fs::OpenFlags::OF_None; - std::string outputFileName("MLIR"); - outputFileName = outputFileName + grammarName + "Visitor.h"; - auto Out = llvm::ToolOutputFile(outputFileName, EC, openFlags); - frontendgen::CGModule CGmodule(module, Out.os(), terminators); - CGmodule.emitMLIRVisitor(grammarName); - Out.keep(); - } - // Free memory. - for (auto rule : module->getRules()) { - for (auto generatorsAndOthers : rule->getGeneratorsAndOthers()) { - for (auto element : generatorsAndOthers->getGenerator()) { - delete element; - } - delete generatorsAndOthers; - } - delete rule; - } +// llvm::cl::opt emitAction( +// "emit", llvm::cl::desc("Select the kind of output desired"), +// llvm::cl::values(clEnumValN(dumpAst, "ast", "Out put the ast")), +// llvm::cl::values(clEnumValN(dumpAntlr, "g4", "Out put the g4 file")), +// llvm::cl::values(clEnumValN(dumpVisitor, "visitor", +// "Out put the visitor file")), +// llvm::cl::values(clEnumValN(dumpAll, "all", "put out all file"))); - delete module->getDialect(); - for (auto op : module->getOps()) { - delete op->getArguments(); - delete op->getResults(); - for (auto builder : op->getBuilders()) { - delete builder->getDag(); - delete builder; - } - delete op; - } - delete module; +int dumpAST(fegen::FegenParser::FegenSpecContext *moduleAST) { + llvm::errs() << moduleAST->toStringTree(1 /* prety format*/) << "\n"; + return 0; } int main(int argc, char *argv[]) { llvm::cl::ParseCommandLineOptions(argc, argv); - llvm::ErrorOr> file = - llvm::MemoryBuffer::getFile(inputFileName.c_str()); - if (std::error_code bufferError = file.getError()) { - llvm::errs() << "error read: " << bufferError.message() << '\n'; - exit(1); - } - llvm::SourceMgr srcMgr; - srcMgr.AddNewSourceBuffer(std::move(*file), llvm::SMLoc()); - frontendgen::DiagnosticEngine diagnostic(srcMgr); - frontendgen::Lexer lexer(srcMgr, diagnostic); - frontendgen::Sema action; - frontendgen::Terminators terminators; - frontendgen::Parser parser(lexer, action, terminators); - frontendgen::Module *module = parser.parser(); - emit(module, terminators); + + // Parse the input file with ANTLR. + std::fstream in(inputFileName); + antlr4::ANTLRInputStream input(in); + fegen::FegenLexer lexer(&input); + antlr4::CommonTokenStream tokens(&lexer); + fegen::FegenParser parser(&tokens); + auto moduleAST = parser.fegenSpec(); + + fegen::FegenVisitor visitor; + visitor.visit(moduleAST); + visitor.emitG4(); + visitor.emitTypeDefination(); + visitor.emitDialectDefination(); + visitor.emitOpDefination(); + visitor.emitBuiltinFunction(moduleAST); return 0; } diff --git a/frontend/FrontendGen/include/AST.h b/frontend/FrontendGen/include/AST.h deleted file mode 100644 index 549f68f2a..000000000 --- a/frontend/FrontendGen/include/AST.h +++ /dev/null @@ -1,211 +0,0 @@ -//====- AST.h -------------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_AST_H -#define INCLUDE_AST_H -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/SMLoc.h" -#include -namespace frontendgen { - -/// Base class for all generator nodes. -class AntlrBase { -public: - enum baseKind { rule, terminator, pbexpression }; - -private: - baseKind kind; - -protected: - llvm::StringRef name; - llvm::SMLoc loc; - -public: - virtual ~AntlrBase(){}; - AntlrBase(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : kind(kind), name(name), loc(loc) {} - llvm::StringRef getName() { return name; } - llvm::SMLoc getLoc() { return loc; } - baseKind getKind() const { return kind; } -}; - -class GeneratorAndOthers { - std::vector generator; - llvm::SmallVector builderNames; - llvm::SmallVector builderIdxs; - -public: - void setbuilderNames(llvm::SmallVector &builderNames) { - this->builderNames = builderNames; - } - void setbuilderIdxs(llvm::SmallVector &builderIdxs) { - this->builderIdxs = builderIdxs; - } - std::vector &getGenerator() { return generator; } - llvm::SmallVector getBuilderNames() { - return this->builderNames; - } - llvm::SmallVector getBuilderIndices() { return this->builderIdxs; } -}; - -/// This class is used to mark the node in the generator as a rule, and can also -/// store the generators of a rule. -class Rule : public AntlrBase { - std::vector generatorsAndOthers; - -public: - Rule(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : AntlrBase(name, loc, kind) {} - static bool classof(const AntlrBase *base) { - return base->getKind() == baseKind::rule; - } - void setGenerators(std::vector &generatorsAndOthers) { - this->generatorsAndOthers = generatorsAndOthers; - } - std::vector getGeneratorsAndOthers() { - return generatorsAndOthers; - } -}; -/// The class is used to mark the node in the generator as a terminator. -class Terminator : public AntlrBase { -public: - Terminator(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : AntlrBase(name, loc, kind) {} - static bool classof(const AntlrBase *base) { - return base->getKind() == baseKind::terminator; - } -}; -/// The class is used to mark the node in the generator as regular expressions. -class PBExpression : public AntlrBase { -public: - PBExpression(llvm::StringRef name, llvm::SMLoc loc, baseKind kind) - : AntlrBase(name, loc, kind) {} - static bool classof(const AntlrBase *base) { - return base->getKind() == baseKind::terminator; - } -}; - -/// The class is used to store the information about Dialect class in the -/// TableGen. -class Dialect { - llvm::StringRef defName; - llvm::StringRef name; - llvm::StringRef cppNamespace; - -public: - Dialect() {} - llvm::StringRef getName() { return name; } - llvm::StringRef getCppNamespace() { return cppNamespace; } - llvm::StringRef getDefName() { return defName; } - void setName(llvm::StringRef name) { this->name = name; } - void setDefName(llvm::StringRef defName) { this->defName = defName; } - void setCppNamespace(llvm::StringRef cppNamespace) { - this->cppNamespace = cppNamespace; - } -}; - -class DAG { - llvm::StringRef dagOperator; - llvm::SmallVector operands; - llvm::SmallVector operandNames; - llvm::StringMap values; - -public: - DAG(){}; - DAG(const DAG &dag) { - this->dagOperator = dag.dagOperator; - this->operands = dag.operands; - this->operandNames = dag.operandNames; - this->values = dag.values; - } - - void addOperand(llvm::StringRef operand, llvm::StringRef operandName) { - operands.push_back(operand); - operandNames.push_back(operandName); - } - void setValue(llvm::StringRef operand, llvm::StringRef value) { - values[operand] = value; - } - llvm::StringRef findValue(llvm::StringRef operand) { - if (values.find(operand) == values.end()) - return llvm::StringRef(); - return values[operand]; - } - llvm::StringRef getDagOperater() { return dagOperator; } - void setDagOperatpr(llvm::StringRef dagOperator) { - this->dagOperator = dagOperator; - } - llvm::SmallVector getOperands() { return operands; } - llvm::SmallVector getOperandNames() { - return operandNames; - } -}; -/// The class is used to store builder in Op class. -class Builder { - DAG *dag = nullptr; - llvm::StringRef code; - -public: - Builder(DAG *dag, llvm::StringRef code) { - this->dag = dag; - this->code = code; - } - DAG *getDag() { return dag; } - llvm::StringRef getCode() { return code; } -}; - -/// The class is used to store information about Op class in the TableGen. -class Op { - llvm::StringRef opName; - DAG *arguments; - DAG *results; - std::vector builders; - -public: - llvm::StringRef getOpName() { return opName; } - DAG *getArguments() { return arguments; } - DAG *getResults() { return results; } - std::vector getBuilders() { return builders; } - - void setOpName(llvm::StringRef opName) { this->opName = opName; } - - void setArguments(DAG *arguments) { this->arguments = arguments; } - void setResults(DAG *results) { this->results = results; } - void setBuilders(std::vector &builders) { - this->builders = builders; - } -}; - -/// This class will become the root of a tree which contains all information we -/// need to generate code. -class Module { - std::vector rules; - Dialect *dialect; - std::vector ops; - -public: - std::vector &getRules() { return rules; } - Dialect *getDialect() { return dialect; } - std::vector &getOps() { return ops; } - void setRules(std::vector &rules) { this->rules = rules; } - void seDialect(Dialect *&dialect) { this->dialect = dialect; } - void setOps(std::vector &ops) { this->ops = ops; } -}; - -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/CGModule.h b/frontend/FrontendGen/include/CGModule.h deleted file mode 100644 index 7fc769d94..000000000 --- a/frontend/FrontendGen/include/CGModule.h +++ /dev/null @@ -1,73 +0,0 @@ -//====- CGModule.h -------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_CGMODULE_H -#define INCLUDE_CGMODULE_H -#include "AST.h" -#include "Terminator.h" -#include "llvm/Support/raw_ostream.h" -namespace frontendgen { - -/// TypeMap is used to store type maps.The cppMap is used to map c++ types, -/// argumentsMap and resultsMap are used to map TableGen types. -class TypeMap { - llvm::StringMap cppMap; - llvm::StringMap argumentsMap; - llvm::StringMap resultsMap; - -public: - TypeMap() { -#define CPPMAP(key, value) cppMap.insert(std::pair(key, value)); -#define RESULTSMAP(key, value) resultsMap.insert(std::pair(key, value)); -#define ARGUMENTSMAP(key, value) argumentsMap.insert(std::pair(key, value)); -#include "TypeMap.def" - } - llvm::StringRef findCppMap(llvm::StringRef value); - llvm::StringRef findArgumentMap(llvm::StringRef value); - llvm::StringRef findResultsMap(llvm::StringRef value); -}; - -/// The class for code generation. -class CGModule { - Terminators &terminators; - Module *module; - llvm::raw_fd_ostream &os; - TypeMap typeMap; - -public: - CGModule(Module *module, llvm::raw_fd_ostream &os, Terminators &terminators) - : terminators(terminators), module(module), os(os) {} - void emitAST(); - void emitAntlr(llvm::StringRef grammarName); - void emit(const std::vector &rules); - void emit(const std::vector &generators); - void emit(const std::vector &generator); - void emitGrammar(llvm::StringRef grammarName); - void emitTerminators(); - void emitCustomTerminators(); - void emitWSAndComment(); - void emitIncludes(llvm::StringRef grammarName); - void emitMLIRVisitor(llvm::StringRef grammarName); - void emitClass(llvm::StringRef grammarName); - void emitRuleVisitor(llvm::StringRef grammarName, Rule *rule); - void emitBuilders(Rule *rule); - void emitBuilder(llvm::StringRef builderOp, int index); - Op *findOp(llvm::StringRef opName); - void emitOp(Op *op, int index); -}; -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/Diagnostics.def b/frontend/FrontendGen/include/Diagnostics.def deleted file mode 100644 index c2afe0f63..000000000 --- a/frontend/FrontendGen/include/Diagnostics.def +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef DIAG -#define DIAG(ID, Level, Msg) -#endif -DIAG(err_expected, Error, "expected {0} but found {1}") -DIAG(err_no_mnemonic, Warning, "you should indicate mnemonic.") -DIAG(err_not_supported_element, Error, "the {0} is not supported." ) -DIAG(err_no_name, Error, "opinterface should indicate the interface name.") -DIAG(err_only_supported_builder, Error, "we are only support builder") -DIAG(err_builder_fail,Error, "builder indicate failed.") -#undef DIAG diff --git a/frontend/FrontendGen/include/Diagnostics.h b/frontend/FrontendGen/include/Diagnostics.h deleted file mode 100644 index 54e475d77..000000000 --- a/frontend/FrontendGen/include/Diagnostics.h +++ /dev/null @@ -1,52 +0,0 @@ -//====- Diagnostics.h -----------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_DIAGNOSTIC_H -#define INCLUDE_DIAGNOSTIC_H -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/SourceMgr.h" - -/// When there is an error in the user's code, we can diagnose the error through -/// the class. -namespace frontendgen { -class DiagnosticEngine { - llvm::SourceMgr &SrcMgr; - static const char *getDiagnosticText(unsigned diagID); - llvm::SourceMgr::DiagKind getDiagnosticKind(unsigned diagID); - bool hasReport = false; - -public: - enum diagKind { -#define DIAG(ID, Level, Msg) ID, -#include "Diagnostics.def" - }; - DiagnosticEngine(llvm::SourceMgr &SrcMgr) : SrcMgr(SrcMgr) {} - - template - void report(llvm::SMLoc loc, unsigned diagID, Args &&...arguments) { - if (!hasReport) { - std::string Msg = llvm::formatv(getDiagnosticText(diagID), - std::forward(arguments)...) - .str(); - SrcMgr.PrintMessage(loc, getDiagnosticKind(diagID), Msg); - hasReport = true; - } - } -}; - -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/FegenManager.h b/frontend/FrontendGen/include/FegenManager.h new file mode 100644 index 000000000..5d358f8ff --- /dev/null +++ b/frontend/FrontendGen/include/FegenManager.h @@ -0,0 +1,815 @@ +#ifndef FEGEN_MANAGER_H +#define FEGEN_MANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +#include "FegenParser.h" +#include "ParserRuleContext.h" + +#define FEGEN_PLACEHOLDER "Placeholder" +#define FEGEN_TYPE "Type" +#define FEGEN_TYPETEMPLATE "TypeTemplate" +#define FEGEN_INTEGER "Integer" +#define FEGEN_FLOATPOINT "FloatPoint" +#define FEGEN_STRING "String" +#define FEGEN_VECTOR "Vector" +#define FEGEN_TENSOR "Tensor" +#define FEGEN_LIST "List" +#define FEGEN_OPTINAL "Optional" +#define FEGEN_ANY "Any" +#define FEGEN_DIALECT_NAME "fegen_builtin" +#define FEGEN_NOT_IMPLEMENTED_ERROR false + +namespace fegen { +class Type; +class Manager; +class Value; +class RightValue; + +using TypePtr = std::shared_ptr; +using largestInt = long long int; +// binary operation + +enum class FegenOperator { + OR, + AND, + EQUAL, + NOT_EQUAL, + LESS, + LESS_EQUAL, + GREATER, + GREATER_EQUAL, + ADD, + SUB, + MUL, + DIV, + MOD, + POWER, + NEG, + NOT +}; + +// user defined function +class Function { +private: + // cpp function name + std::string name; + // input object + std::vector inputTypeList; + // return type + TypePtr returnType; + explicit Function(std::string name, std::vector &&inputTypeList, + TypePtr returnType); + +public: + static Function *get(std::string name, std::vector inputTypeList, + TypePtr returnType = nullptr); + ~Function() = default; + std::string getName(); + std::vector &getInputTypeList(); + Value *getInputTypeList(size_t i); + TypePtr getReturnType(); +}; + +class Value; + +// user defined operation +class Operation { +private: + std::string dialectName; + std::string operationName; + // arguments of operation + std::vector arguments; + // results of operation + std::vector results; + // operation body context + FegenParser::BodySpecContext *ctx; + explicit Operation(std::string dialectName, std::string operationName, + std::vector &&arguments, + std::vector &&results, + FegenParser::BodySpecContext *ctx); + +public: + void setOpName(std::string); + std::string getOpName(); + std::vector &getArguments(); + Value *getArguments(size_t i); + std::vector &getResults(); + Value *getResults(size_t i); + static Operation *get(std::string operationName, + std::vector arguments, + std::vector results, + FegenParser::BodySpecContext *ctx); + ~Operation() = default; +}; + +class TypeDefination; +class RightValue; +class Type { + friend class Value; + +public: + enum class TypeKind { ATTRIBUTE, OPERAND, CPP }; + +private: + TypeKind kind; + std::string typeName; + // std::vector parameters; + TypeDefination *typeDefine; + int typeLevel; + bool isConstType; + +public: + Type(TypeKind kind, std::string name, TypeDefination *tyDef, int typeLevel, + bool isConstType); + + Type(const Type &) = default; + Type(Type &&) = default; + TypeKind getTypeKind(); + void setTypeKind(TypeKind kind); + TypeDefination *getTypeDefination(); + void setTypeDefination(TypeDefination *tyDef); + std::string getTypeName(); + int getTypeLevel(); + bool isConstant(); + // for generating typedef td file. + virtual std::string toStringForTypedef(); + // for generating op def td file. + virtual std::string toStringForOpdef(); + // for generating cpp type kind. + virtual std::string toStringForCppKind(); + static bool isSameType(TypePtr type1, TypePtr type2); + virtual ~Type() = default; + + // placeholder + static TypePtr getPlaceHolder(); + + // Type + static TypePtr getMetaType(); + + // TypeTemplate + static TypePtr getMetaTemplateType(); + + // int + static TypePtr getInt32Type(); + + // float + static TypePtr getFloatType(); + + // float + static TypePtr getDoubleType(); + + // bool + static TypePtr getBoolType(); + + // Integer + static TypePtr getIntegerType(RightValue size); + + // FloatPoint + static TypePtr getFloatPointType(RightValue size); + + // string + static TypePtr getStringType(); + + // List + static TypePtr getListType(TypePtr elementType); + static TypePtr getListType(RightValue elementType); + + // Vector + static TypePtr getVectorType(TypePtr elementType, RightValue size); + static TypePtr getVectorType(RightValue elementType, RightValue size); + + // Tensor + static TypePtr getTensorType(TypePtr elementType); + static TypePtr getTensorType(RightValue elementType); + + // Optional + static TypePtr getOptionalType(TypePtr elementType); + static TypePtr getOptionalType(RightValue elementType); + + // Any<[elementType1, elementType2, ...]> + static TypePtr getAnyType(RightValue elementTypes); + + static TypePtr getCustomeType(std::vector params, + TypeDefination *tydef); + + // Integer + static TypePtr getIntegerTemplate(); + + // FloatPoint + static TypePtr getFloatPointTemplate(); + + // string + static TypePtr getStringTemplate(); + + // List (elementType is template) + static TypePtr getListTemplate(TypePtr elementType); + static TypePtr getListTemplate(RightValue elementType); + + // Vector + static TypePtr getVectorTemplate(); + + // Tensor + static TypePtr getTensorTemplate(); + + // Optional (elementType is template) + static TypePtr getOptionalTemplate(TypePtr elementType); + static TypePtr getOptionalTemplate(RightValue elementType); + + // Any<[elementType1, elementType2, ...]> (elementType* is template) + static TypePtr getAnyTemplate(RightValue elementTypes); + + static TypePtr getCustomeTemplate(TypeDefination *tydef); +}; + +class TypeDefination { + friend class Manager; + +private: + std::string dialectName; + std::string name; + std::vector parameters; + FegenParser::TypeDefinationDeclContext *ctx; + bool ifCustome; + std::string mnemonic; + +public: + TypeDefination(std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome); + static TypeDefination *get(std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome = true); + std::string getDialectName(); + void setDialectName(std::string); + std::string getName(); + std::string getMnemonic(); + void setName(std::string); + const std::vector &getParameters(); + FegenParser::TypeDefinationDeclContext *getCtx(); + void setCtx(FegenParser::TypeDefinationDeclContext *); + bool isCustome(); +}; + +/// @brief Represent right value, and pass by value. +class RightValue { + friend class Type; + friend class Value; + +public: + enum class LiteralKind { + MONOSTATE, + INT, + FLOAT, + STRING, + TYPE, + VECTOR, + LEFT_VAR, + FUNC_CALL, + OPERATION_CALL, + OPERATOR_CALL + }; + struct ExpressionNode; + struct FunctionCall; + struct OperationCall; + struct OperatorCall; + struct ExpressionTerminal; + struct Expression { + bool ifTerminal; + LiteralKind kind; + bool isLiteral; + bool ifConstexpr; + Expression(bool, LiteralKind, bool); + virtual ~Expression() = default; + virtual bool isTerminal(); + virtual std::string toString() = 0; + virtual std::string toStringForTypedef() = 0; + virtual std::string toStringForOpdef() = 0; + virtual std::string toStringForCppKind() = 0; + LiteralKind getKind(); + virtual TypePtr getType() = 0; + virtual std::any getContent() = 0; + virtual bool isConstexpr(); + + /// @brief operate lhs and rhs using binary operator. + static std::shared_ptr + binaryOperation(std::shared_ptr lhs, + std::shared_ptr rhs, FegenOperator op); + /// @brief operate expr using unary operator + static std::shared_ptr + unaryOperation(std::shared_ptr, FegenOperator); + + // TODO: callFunction + static std::shared_ptr + callFunction(std::vector>, Function *); + + // TODO: callOperation + static std::shared_ptr + callOperation(std::vector>, Operation *); + + static std::shared_ptr getPlaceHolder(); + static std::shared_ptr getInteger(largestInt, + size_t size = 32); + static std::shared_ptr getFloatPoint(long double, + size_t size = 32); + static std::shared_ptr getString(std::string); + static std::shared_ptr getTypeRightValue(TypePtr); + static std::shared_ptr + getList(std::vector> &); + static std::shared_ptr getLeftValue(fegen::Value *); + }; + + struct ExpressionNode : public Expression { + ExpressionNode(LiteralKind, bool); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override = 0; + virtual TypePtr getType() override; + }; + + struct FunctionCall : public ExpressionNode { + Function *func; + std::vector> params; + FunctionCall(Function *, std::vector>); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override; + virtual TypePtr getType() override; + }; + + struct OperationCall : public ExpressionNode { + Operation *op; + std::vector> params; + OperationCall(Operation *, std::vector>); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override; + virtual TypePtr getType() override; + }; + + struct OperatorCall : public ExpressionNode { + static std::unordered_map cppOperatorMap; + FegenOperator op; + std::vector> params; + OperatorCall(FegenOperator, std::vector>); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override; + virtual TypePtr getType() override; + }; + + struct ExpressionTerminal : public Expression { + ExpressionTerminal(LiteralKind, bool); + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual std::any getContent() override = 0; + virtual TypePtr getType() override; + }; + + struct PlaceHolder : public ExpressionTerminal { + PlaceHolder(); + virtual std::any getContent() override; + virtual std::string toString() override; + }; + + struct IntegerLiteral : public ExpressionTerminal { + size_t size; + largestInt content; + IntegerLiteral(largestInt content, size_t size); + virtual std::any getContent() override; + virtual std::string toString() override; + virtual std::string toStringForCppKind() override; + virtual TypePtr getType() override; + }; + + struct FloatPointLiteral : public ExpressionTerminal { + size_t size; + long double content; + FloatPointLiteral(long double content, size_t size); + virtual std::any getContent() override; + virtual std::string toString() override; + virtual std::string toStringForCppKind() override; + virtual TypePtr getType() override; + }; + + struct StringLiteral : public ExpressionTerminal { + std::string content; + StringLiteral(std::string content); + virtual std::any getContent() override; + virtual std::string toString() override; + virtual std::string toStringForCppKind() override; + virtual TypePtr getType() override; + }; + + struct TypeLiteral : public ExpressionTerminal { + TypePtr content; + TypeLiteral(TypePtr content); + virtual std::any getContent() override; + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual std::string toStringForCppKind() override; + virtual TypePtr getType() override; + }; + + struct ListLiteral : public ExpressionTerminal { + std::vector> content; + ListLiteral(std::vector> &content); + virtual std::any getContent() override; + virtual std::string toString() override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; + virtual TypePtr getType() override; + }; + + struct LeftValue : public ExpressionTerminal { + Value *content; + LeftValue(Value *content); + virtual std::any getContent() override; + virtual std::string toString() override; + virtual std::string toStringForCppKind() override; + virtual TypePtr getType() override; + }; + +public: + using ExprPtr = std::shared_ptr; + RightValue(std::shared_ptr); + RightValue(const RightValue &) = default; + RightValue(RightValue &&) = default; + RightValue &operator=(const RightValue &another) = default; + RightValue::LiteralKind getLiteralKind(); + std::string toString(); + std::string toStringForTypedef(); + std::string toStringForOpdef(); + std::string toStringForCppKind(); + std::any getContent(); + TypePtr getType(); + std::shared_ptr getExpr(); + bool isConstant(); + + static RightValue getPlaceHolder(); + static RightValue getInteger(largestInt content, size_t size = 32); + static RightValue getFloatPoint(long double content, size_t size = 32); + static RightValue getString(std::string content); + static RightValue getTypeRightValue(TypePtr content); + static RightValue getList(std::vector> &content); + static RightValue getLeftValue(fegen::Value *content); + static RightValue getByExpr(std::shared_ptr expr); + ~RightValue() = default; + +private: + std::shared_ptr content; +}; + +// PlaceHolder +class PlaceHolderType : public Type { +public: + PlaceHolderType(); +}; + +// Type +class MetaType : public Type { +public: + MetaType(); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; +}; +// Template +class MetaTemplate : public Type { +public: + MetaTemplate(); +}; +// Integer +class IntegerType : public Type { + RightValue size; + +public: + IntegerType(RightValue size, TypeDefination *tyDef); + IntegerType(RightValue size); + largestInt getSize(); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// FloatPoint +class FloatPointType : public Type { + RightValue size; + +public: + FloatPointType(RightValue size); + largestInt getSize(); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// String +class StringType : public Type { +public: + StringType(); + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// List +class ListType : public Type { + RightValue elementType; + +public: + ListType(RightValue elementType); + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; + // for generating cpp type kind. + virtual std::string toStringForCppKind() override; +}; +// Vector +class VectorType : public Type { + RightValue elementType; + RightValue size; + +public: + VectorType(RightValue elementType, RightValue size); +}; +// Tensor +class TensorType : public Type { + RightValue elementType; + public: + TensorType(RightValue elementType); + virtual std::string toStringForOpdef() override; +}; +// Optional +class OptionalType : public Type { + RightValue elementType; + +public: + OptionalType(RightValue elementType); +}; +// Any<[ty1, ty2, ...]> +class AnyType : public Type { + RightValue elementTypes; + +public: + AnyType(RightValue elementTypes); +}; +// custome type +class CustomeType : public Type { + std::vector params; + +public: + CustomeType(std::vector params, TypeDefination *tydef); +}; + +class TemplateType : public Type { +public: + TemplateType(TypeDefination *tydef); + virtual TypePtr instantiate(std::vector params) = 0; + virtual ~TemplateType() = default; +}; + +// Integer +class IntegerTemplateType : public TemplateType { +public: + IntegerTemplateType(); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; +}; +// FloatPoint +class FloatPointTemplateType : public TemplateType { +public: + FloatPointTemplateType(); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; +}; +// String +class StringTemplateType : public TemplateType { +public: + StringTemplateType(); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; +}; +// List (ty is a template) +class ListTemplateType : public TemplateType { + RightValue elementType; + +public: + ListTemplateType(RightValue elementType); + virtual TypePtr instantiate(std::vector params) override; + virtual std::string toStringForTypedef() override; + virtual std::string toStringForOpdef() override; +}; +// Vector +class VectorTemplateType : public TemplateType { +public: + VectorTemplateType(); + virtual TypePtr instantiate(std::vector params) override; +}; +// Tensor +class TensorTemplateType : public TemplateType { +public: + TensorTemplateType(); + virtual TypePtr instantiate(std::vector params) override; +}; +// Optional (ty is a template) +class OptionalTemplateType : public TemplateType { + RightValue elementType; + +public: + OptionalTemplateType(RightValue elementType); + virtual TypePtr instantiate(std::vector params) override; +}; +// Any<[ty1, ty2, ...]> (ty* is a template) +class AnyTemplateType : public TemplateType { + RightValue elementTypes; + +public: + AnyTemplateType(RightValue elementTypes); + virtual TypePtr instantiate(std::vector params) override; +}; +// custome type +class CustomeTemplateType : public TemplateType { +public: + CustomeTemplateType(TypeDefination *tydef); + virtual TypePtr instantiate(std::vector params) override; + // for generating typedef td file. + virtual std::string toStringForTypedef() override; + // for generating op def td file. + virtual std::string toStringForOpdef() override; +}; + +class Value { + friend class Type; + +private: + TypePtr type; + std::string name; + RightValue content; + +public: + Value(TypePtr type, std::string name, RightValue content); + Value(const Value &rhs); + Value(Value &&rhs); + + static Value *get(TypePtr type, std::string name, RightValue constant); + + std::string getName(); + TypePtr getType(); + /// @brief return content of right value, get ExprssionNode* if kind is + /// EXPRESSION. + template T getContent() { + return std::any_cast(this->content.getContent()); + } + void setContent(fegen::RightValue content); + RightValue::LiteralKind getContentKind(); + std::string getContentString(); + std::string getContentStringForTypedef(); + std::string getContentStringForOpdef(); + std::string getContentStringForCppKind(); + std::shared_ptr getExpr(); + ~Value() = default; +}; + +class ParserNode; + +class ParserRule { + friend class Manager; + +private: + std::string content; + // from which node + ParserNode *src; + std::map inputs; + std::map returns; + // context in parser tree + antlr4::ParserRuleContext *ctx; + explicit ParserRule(std::string content, ParserNode *src, + antlr4::ParserRuleContext *ctx); + +public: + static ParserRule *get(std::string content, ParserNode *src, + antlr4::ParserRuleContext *ctx); + llvm::StringRef getContent(); + // check and add input value + bool addInput(Value input); + // check and add return value + bool addReturn(Value output); + // set source node + void setSrc(ParserNode *src); +}; + +class ParserNode { + friend class Manager; + +public: + enum class NodeType { PARSER_RULE, LEXER_RULE }; + +private: + std::vector rules; + antlr4::ParserRuleContext *ctx; + NodeType ntype; + explicit ParserNode(std::vector &&rules, + antlr4::ParserRuleContext *ctx, NodeType ntype); + +public: + static ParserNode *get(std::vector rules, + antlr4::ParserRuleContext *ctx, NodeType ntype); + static ParserNode *get(antlr4::ParserRuleContext *ctx, NodeType ntype); + void addFegenRule(ParserRule *rule); + // release rules first + ~ParserNode(); +}; + +class FegenVisitor; + +class Manager { + friend class FegenVisitor; + +private: + struct OverloadedType { + llvm::SmallVector tys; + OverloadedType(TypeDefination *); + OverloadedType(std::initializer_list &&); + TypeDefination *get(unsigned i); + }; + +private: + std::map typeDefMap; + Manager(); + Manager(const Manager &) = delete; + const Manager &operator=(const Manager &) = delete; + // release nodes, type, operation, function + ~Manager(); + void initbuiltinTypes(); + +public: + std::string moduleName; + std::vector headFiles; + std::map nodeMap; + llvm::StringMap typeMap; + + std::map operationMap; + std::map functionMap; + // stmt contents + std::unordered_map stmtContentMap; + void addStmtContent(antlr4::ParserRuleContext *ctx, std::any content); + template T getStmtContent(antlr4::ParserRuleContext *ctx) { + assert(this->stmtContentMap.count(ctx)); + return std::any_cast(this->stmtContentMap[ctx]); + } + + static Manager &getManager(); + void setModuleName(std::string name); + + TypeDefination *getTypeDefination(std::string name); + TypeDefination *getOverloadedTypeDefination(std::string name); + bool addTypeDefination(TypeDefination *tyDef); + bool addOverloadedTypeDefination(TypeDefination *tyDef); + + Operation *getOperationDefination(std::string name); + bool addOperationDefination(Operation *opDef); + void emitG4(); + void emitTypeDefination(); + void emitOpDefination(); + void emitDialectDefination(); + void emitTdFiles(); + void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *); +}; + +TypePtr inferenceType(std::vector>, + FegenOperator); + +} // namespace fegen + +#endif diff --git a/frontend/FrontendGen/include/FegenVisitor.h b/frontend/FrontendGen/include/FegenVisitor.h new file mode 100644 index 000000000..06b6416d5 --- /dev/null +++ b/frontend/FrontendGen/include/FegenVisitor.h @@ -0,0 +1,854 @@ +#ifndef FEGEN_FEGENVISITOR_H +#define FEGEN_FEGENVISITOR_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +#include "FegenManager.h" +#include "FegenParser.h" +#include "FegenParserBaseVisitor.h" +#include "Scope.h" + +using namespace antlr4; + +namespace fegen { + +/// @brief check if params are right. +/// @param expected expected params. +/// @param actual actual params. +/// @return true if correct. +bool checkParams(std::vector &expected, + std::vector &actual); + +/// @brief check if the type of elements in list are correct. +bool checkListLiteral( + std::vector> &listLiteral); + +class FegenVisitor : public FegenParserBaseVisitor { +private: + Manager &manager; + ScopeStack &sstack; + +public: + void emitG4() { this->manager.emitG4(); } + void emitTypeDefination() { this->manager.emitTypeDefination(); } + void emitDialectDefination() { this->manager.emitDialectDefination(); } + void emitOpDefination() { this->manager.emitOpDefination(); } + void emitBuiltinFunction(fegen::FegenParser::FegenSpecContext *moduleAST) { + this->manager.emitBuiltinFunction(moduleAST); + } + + FegenVisitor() + : manager(Manager::getManager()), sstack(ScopeStack::getScopeStack()) { + this->manager.initbuiltinTypes(); + } + + std::any visitTypeDefinationDecl( + FegenParser::TypeDefinationDeclContext *ctx) override { + auto typeName = ctx->typeDefinationName()->getText(); + auto tyDef = std::any_cast( + this->visit(ctx->typeDefinationBlock())); + // set name and ctx for type defination + tyDef->setName(typeName); + tyDef->setCtx(ctx); + // add defination to manager map + this->manager.addTypeDefination(tyDef); + return nullptr; + } + + // return FegenTypeDefination* + std::any visitTypeDefinationBlock( + FegenParser::TypeDefinationBlockContext *ctx) override { + auto params = + std::any_cast>(this->visit(ctx->parametersSpec())); + auto tyDef = + TypeDefination::get(this->manager.moduleName, "", params, nullptr); + return tyDef; + } + + std::any visitFegenDecl(FegenParser::FegenDeclContext *ctx) override { + this->manager.setModuleName(ctx->identifier()->getText()); + return nullptr; + } + + std::any + visitParserRuleSpec(FegenParser::ParserRuleSpecContext *ctx) override { + auto ruleList = + std::any_cast>(this->visit(ctx->ruleBlock())); + auto ruleNode = + ParserNode::get(ruleList, ctx, ParserNode::NodeType::PARSER_RULE); + // set source node for rules + for (auto rule : ruleList) { + rule->setSrc(ruleNode); + } + this->manager.nodeMap.insert({ctx->ParserRuleName()->getText(), ruleNode}); + return nullptr; + } + + std::any visitRuleAltList(FegenParser::RuleAltListContext *ctx) override { + std::vector ruleList; + for (auto alt : ctx->actionAlt()) { + auto fegenRule = std::any_cast(this->visit(alt)); + ruleList.push_back(fegenRule); + } + return ruleList; + } + + std::any visitActionAlt(FegenParser::ActionAltContext *ctx) override { + auto rawRule = this->visit(ctx->alternative()); + if (ctx->actionBlock()) { + auto blockValues = + std::any_cast, std::vector>>( + this->visit(ctx->actionBlock())); + auto inputs = std::get<0>(blockValues); + auto returns = std::get<1>(blockValues); + auto rule = std::any_cast(rawRule); + for (auto in : inputs) { + auto flag = rule->addInput(*in); + if (!flag) { // TODO: error report + std::cerr << "input of " << rule->getContent().str() << " \"" + << in->getName() << "\" existed." << std::endl; + } + } + for (auto out : returns) { + auto flag = rule->addReturn(*out); + if (!flag) { // TODO: error report + std::cerr << "return of " << rule->getContent().str() << " \"" + << out->getName() << "\" existed." << std::endl; + } + } + } + return rawRule; + } + + // return tuple, vector> + std::any visitActionBlock(FegenParser::ActionBlockContext *ctx) override { + std::vector inputs; + std::vector returns; + if (ctx->inputsSpec()) { + inputs = + std::any_cast>(this->visit(ctx->inputsSpec())); + } + + if (ctx->returnsSpec()) { + returns = + std::any_cast>(this->visit(ctx->returnsSpec())); + } + + if (ctx->actionSpec()) { + this->visit(ctx->actionSpec()); + } + return std::tuple(inputs, returns); + } + + // return FegenRule Object + // TODO: do more check + std::any visitAlternative(FegenParser::AlternativeContext *ctx) override { + auto content = ctx->getText(); + auto rule = ParserRule::get(content, nullptr, ctx); + return rule; + } + + std::any visitLexerRuleSpec(FegenParser::LexerRuleSpecContext *ctx) override { + // create node, get rules from child, and insert to node map + auto ruleList = std::any_cast>( + this->visit(ctx->lexerRuleBlock())); + auto ruleNode = + ParserNode::get(ruleList, ctx, ParserNode::NodeType::LEXER_RULE); + // set source node for rules + for (auto rule : ruleList) { + rule->setSrc(ruleNode); + } + this->manager.nodeMap.insert({ctx->LexerRuleName()->getText(), ruleNode}); + return nullptr; + } + + std::any visitLexerAltList(FegenParser::LexerAltListContext *ctx) override { + std::vector ruleList; + for (auto alt : ctx->lexerAlt()) { + auto rule = fegen::ParserRule::get(alt->getText(), nullptr, alt); + ruleList.push_back(rule); + } + return ruleList; + } + + // return vector + std::any visitVarDecls(FegenParser::VarDeclsContext *ctx) override { + size_t varCount = ctx->typeSpec().size(); + std::vector valueList; + for (size_t i = 0; i <= varCount - 1; i++) { + auto ty = std::any_cast(this->visit(ctx->typeSpec(i))); + auto varName = ctx->identifier(i)->getText(); + auto var = + fegen::Value::get(ty, varName, fegen::RightValue::getPlaceHolder()); + valueList.push_back(var); + } + return valueList; + } + + // return fegen::TypePtr + std::any + visitTypeInstanceSpec(FegenParser::TypeInstanceSpecContext *ctx) override { + auto valueKind = ctx->valueKind() ? std::any_cast( + this->visit(ctx->valueKind())) + : fegen::Type::TypeKind::CPP; + auto typeInst = + std::any_cast(this->visit(ctx->typeInstance())); + typeInst->setTypeKind(valueKind); + return typeInst; + } + + // return fegen::FegenType::TypeKind + std::any visitValueKind(FegenParser::ValueKindContext *ctx) override { + auto kind = fegen::Type::TypeKind::ATTRIBUTE; + if (ctx->CPP()) { + kind = fegen::Type::TypeKind::CPP; + } else if (ctx->OPERAND()) { + kind = fegen::Type::TypeKind::OPERAND; + } + // otherwise: ATTRIBUTE + return kind; + } + + // return fegen::TypePtr + std::any visitTypeInstance(FegenParser::TypeInstanceContext *ctx) override { + if (ctx->typeTemplate()) { // typeTemplate (Less typeTemplateParam (Comma + // typeTemplateParam)* Greater)? + auto typeTeplt = + std::any_cast(this->visit(ctx->typeTemplate())); + if (ctx->typeTemplate()->TYPE()) { + return typeTeplt; + } + auto teplt = std::dynamic_pointer_cast(typeTeplt); + // get parameters + std::vector paramList; + for (auto paramCtx : ctx->typeTemplateParam()) { + auto tepltParams = + std::any_cast(this->visit(paramCtx)); + paramList.push_back(tepltParams); + } + + // check parameters + auto expectedParams = teplt->getTypeDefination()->getParameters(); + if (!checkParams(expectedParams, paramList)) { + std::cerr << "parameters error in context: " << ctx->getText() + << std::endl; + exit(0); + } + // get instance + auto typeInst = teplt->instantiate(paramList); + return typeInst; + } else if (ctx->identifier()) { // identifier + auto varName = ctx->identifier()->getText(); + auto var = this->sstack.attemptFindVar(varName); + if (var) { + if (var->getContentKind() == fegen::RightValue::LiteralKind::TYPE) { + return var->getContent(); + } else { + std::cerr << "variable " << varName + << " is not a Type or TypeTemplate." << std::endl; + exit(0); + return nullptr; + } + } else { // variable does not exist. + std::cerr << "undefined variable: " << varName << std::endl; + exit(0); + return nullptr; + } + } else { // builtinTypeInstances + return visitChildren(ctx); + } + } + + // return RightValue + std::any + visitTypeTemplateParam(FegenParser::TypeTemplateParamContext *ctx) override { + if (ctx->builtinTypeInstances()) { + auto ty = std::any_cast( + this->visit(ctx->builtinTypeInstances())); + return fegen::RightValue::getTypeRightValue(ty); + } else { + auto expr = std::any_cast>( + this->visit(ctx->expression())); + return fegen::RightValue::getByExpr(expr); + } + } + + // return fegen::FegenType + std::any visitBuiltinTypeInstances( + FegenParser::BuiltinTypeInstancesContext *ctx) override { + if (ctx->BOOL()) { + return Type::getBoolType(); + } else if (ctx->INT()) { + return Type::getInt32Type(); + } else if (ctx->FLOAT()) { + return Type::getFloatType(); + } else if (ctx->DOUBLE()) { + return Type::getDoubleType(); + } else if (ctx->STRING()) { + return Type::getStringType(); + } else { + std::cerr << "error builtin type." << std::endl; + return nullptr; + } + } + + // return TypePtr + std::any visitTypeTemplate(FegenParser::TypeTemplateContext *ctx) override { + if (ctx->prefixedName()) { // prefixedName + if (ctx->prefixedName()->identifier().size() == 2) { // dialect.type + // TODO: return type from other dialect + return nullptr; + } else { // type + auto tyDef = this->manager.getTypeDefination( + ctx->prefixedName()->identifier(0)->getText()); + return fegen::Type::getCustomeTemplate(tyDef); + } + } else if (ctx->builtinTypeTemplate()) { // builtinTypeTemplate + return this->visit(ctx->builtinTypeTemplate()); + } else { // TYPE + return fegen::Type::getMetaType(); + } + } + + // return TypePtr + std::any visitBuiltinTypeTemplate( + FegenParser::BuiltinTypeTemplateContext *ctx) override { + if (ctx->INTEGER()) { + return fegen::Type::getIntegerTemplate(); + } else if (ctx->FLOATPOINT()) { + return fegen::Type::getFloatPointTemplate(); + } else if (ctx->TENSOR()) { + return fegen::Type::getTensorTemplate(); + } else if (ctx->VECTOR()) { + return fegen::Type::getVectorTemplate(); + } else { + return nullptr; + } + } + + // return TypePtr + std::any + visitCollectTypeSpec(FegenParser::CollectTypeSpecContext *ctx) override { + auto kind = fegen::Type::TypeKind::CPP; + if (ctx->valueKind()) { + kind = + std::any_cast(this->visit(ctx->valueKind())); + } + auto ty = std::any_cast(this->visit(ctx->collectType())); + ty->setTypeKind(kind); + return ty; + } + + // return TypePtr + std::any visitCollectType(FegenParser::CollectTypeContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->expression())); + + if (ctx->collectProtoType()->ANY()) { + // check to get list type. + std::vector tyexpr = + std::any_cast>(expr); + int level = std::any_cast(tyexpr[0]->getContent()) + ->getTypeLevel(); + for (size_t i = 1; i <= tyexpr.size() - 1; i++) { + auto expr = tyexpr[i]; + auto t = std::any_cast(expr->getContent()); + if (level != t->getTypeLevel()) { + assert(false); + } + } + if (level == 1 || level == 2) { // template -> any template + return fegen::Type::getAnyTemplate(fegen::RightValue::getByExpr(expr)); + } else { // instance -> any instance + return fegen::Type::getAnyType(fegen::RightValue::getByExpr(expr)); + } + } else if (ctx->collectProtoType()->LIST()) { + // the same as any + int level = + std::any_cast(expr->getContent())->getTypeLevel(); + if (level == 1 || level == 2) { + return fegen::Type::getListTemplate(fegen::RightValue::getByExpr(expr)); + } else { + return fegen::Type::getListType(fegen::RightValue::getByExpr(expr)); + } + } else { // optional + // the same as any + int level = + std::any_cast(expr->getContent())->getTypeLevel(); + if (level == 1 || level == 2) { + return fegen::Type::getOptionalTemplate( + fegen::RightValue::getByExpr(expr)); + } else { + return fegen::Type::getOptionalType(fegen::RightValue::getByExpr(expr)); + } + } + } + + // return std::shared_ptr + std::any visitExpression(FegenParser::ExpressionContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->andExpr(0))); + for (size_t i = 1; i <= ctx->andExpr().size() - 1; i++) { + auto rhs = std::any_cast>( + this->visit(ctx->andExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, + FegenOperator::OR); + } + manager.addStmtContent(ctx, expr); + return expr; + } + + // return std::shared_ptr + std::any visitAndExpr(FegenParser::AndExprContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->equExpr(0))); + for (size_t i = 1; i <= ctx->equExpr().size() - 1; i++) { + auto rhs = std::any_cast>( + this->visit(ctx->equExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, + FegenOperator::AND); + } + return expr; + } + + // return std::shared_ptr + std::any visitEquExpr(FegenParser::EquExprContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->compareExpr(0))); + for (size_t i = 1; i <= ctx->compareExpr().size() - 1; i++) { + FegenOperator op; + if (ctx->children[2 * i - 1]->getText() == "==") { + op = FegenOperator::EQUAL; + } else { + op = FegenOperator::NOT_EQUAL; + } + auto rhs = std::any_cast>( + this->visit(ctx->compareExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return std::shared_ptr + std::any visitCompareExpr(FegenParser::CompareExprContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->addExpr(0))); + for (size_t i = 1; i <= ctx->addExpr().size() - 1; i++) { + FegenOperator op; + auto opStr = ctx->children[2 * i - 1]->getText(); + if (opStr == "<") { + op = FegenOperator::LESS; + } else if (opStr == "<=") { + op = FegenOperator::LESS_EQUAL; + } else if (opStr == "<=") { + op = FegenOperator::LESS_EQUAL; + } else if (opStr == ">") { + op = FegenOperator::GREATER; + } else { + op = FegenOperator::GREATER_EQUAL; + } + auto rhs = std::any_cast>( + this->visit(ctx->addExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return std::shared_ptr + std::any visitAddExpr(FegenParser::AddExprContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->term(0))); + for (size_t i = 1; i <= ctx->term().size() - 1; i++) { + FegenOperator op; + auto opStr = ctx->children[2 * i - 1]->getText(); + if (opStr == "+") { + op = FegenOperator::ADD; + } else { + op = FegenOperator::SUB; + } + auto rhs = std::any_cast>( + this->visit(ctx->term(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return std::shared_ptr + std::any visitTerm(FegenParser::TermContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->powerExpr(0))); + for (size_t i = 1; i <= ctx->powerExpr().size() - 1; i++) { + FegenOperator op; + auto opStr = ctx->children[2 * i - 1]->getText(); + if (opStr == "*") { + op = FegenOperator::MUL; + } else if (opStr == "/") { + op = FegenOperator::DIV; + } else { + op = FegenOperator::MOD; + } + auto rhs = std::any_cast>( + this->visit(ctx->powerExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, op); + } + return expr; + } + + // return std::shared_ptr + std::any visitPowerExpr(FegenParser::PowerExprContext *ctx) override { + auto expr = std::any_cast>( + this->visit(ctx->unaryExpr(0))); + for (size_t i = 1; i <= ctx->unaryExpr().size() - 1; i++) { + auto rhs = std::any_cast>( + this->visit(ctx->unaryExpr(i))); + expr = RightValue::ExpressionNode::binaryOperation(expr, rhs, + FegenOperator::POWER); + } + return expr; + } + + // return std::shared_ptr + std::any visitUnaryExpr(FegenParser::UnaryExprContext *ctx) override { + if (ctx->children.size() == 1 || ctx->Plus()) { + return this->visit(ctx->primaryExpr()); + } + auto expr = std::any_cast>( + this->visit(ctx->primaryExpr())); + FegenOperator op; + if (ctx->Minus()) { + op = FegenOperator::NEG; + } else { + op = FegenOperator::NOT; + } + expr = RightValue::ExpressionNode::unaryOperation(expr, op); + return expr; + } + + // return std::shared_ptr + std::any visitParenSurroundedExpr( + FegenParser::ParenSurroundedExprContext *ctx) override { + return this->visit(ctx->expression()); + } + + // return std::shared_ptr + std::any visitPrimaryExpr(FegenParser::PrimaryExprContext *ctx) override { + if (ctx->identifier()) { + auto name = ctx->identifier()->getText(); + auto var = this->sstack.attemptFindVar(name); + if (var) { + return (std::shared_ptr) + fegen::RightValue::ExpressionTerminal::getLeftValue(var); + } else { + // TODO + auto tyDef = this->manager.getTypeDefination(name); + if (tyDef) { + auto tyVar = fegen::Type::getCustomeTemplate(tyDef); + return (std::shared_ptr) + fegen::RightValue::Expression::getTypeRightValue(tyVar); + } else { + // TODO: error report + std::cerr << "can not find variable: " << ctx->identifier()->getText() + << "." << std::endl; + assert(false); + return nullptr; + } + } + } else if (ctx->typeSpec()) { + auto ty = std::any_cast(this->visit(ctx->typeSpec())); + return (std::shared_ptr) + RightValue::Expression::getTypeRightValue(ty); + } else { // constant, functionCall, parenSurroundedExpr,contextMethodInvoke, + // and variableAccess + return this->visit(ctx->children[0]); + } + } + + // return std::shared_ptr + std::any visitIntLiteral(FegenParser::IntLiteralContext *ctx) override { + long long int number = std::stoi(ctx->getText()); + size_t size = 32; // TODO: Get size of number. + return (std::shared_ptr) + fegen::RightValue::Expression::getInteger(number, size); + } + + // return std::shared_ptr + std::any visitRealLiteral(FegenParser::RealLiteralContext *ctx) override { + long double number = std::stod(ctx->getText()); + size_t size = 32; // TODO: Get size of number. + return (std::shared_ptr) + fegen::RightValue::Expression::getFloatPoint(number, size); + } + + // return std::shared_ptr + std::any visitCharLiteral(FegenParser::CharLiteralContext *ctx) override { + std::string s = ctx->getText(); + // remove quotation marks + std::string strWithoutQuotation = s.substr(1, s.size() - 2); + return (std::shared_ptr) + fegen::RightValue::Expression::getString(strWithoutQuotation); + } + + // return std::shared_ptr + std::any visitBoolLiteral(FegenParser::BoolLiteralContext *ctx) override { + int content = 0; + if (ctx->getText() == "true") { + content = 1; + } + return (std::shared_ptr) + fegen::RightValue::Expression::getInteger(content, 1); + } + + // return std::shared_ptr + std::any visitListLiteral(FegenParser::ListLiteralContext *ctx) override { + std::vector> elements; + for (auto exprCtx : ctx->expression()) { + auto expr = std::any_cast>( + this->visit(exprCtx)); + elements.push_back(expr); + } + return (std::shared_ptr) + fegen::RightValue::Expression::getList(elements); + } + + std::any visitActionSpec(FegenParser::ActionSpecContext *ctx) override { + return nullptr; + } + + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { + sstack.pushScope(); + auto returnType = + std::any_cast(this->visit(ctx->typeSpec())); + manager.addStmtContent(ctx, returnType); + auto functionName = + std::any_cast(this->visit(ctx->funcName())); + auto hasfunc = manager.functionMap.find(functionName); + if (hasfunc != manager.functionMap.end()) { + std::cerr << "The function name \" " << functionName + << "\" has already been used. Please use another name." + << std::endl; + exit(0); + return nullptr; + } + auto functionParams = std::any_cast>( + this->visit(ctx->funcParams())); + this->visit(ctx->statementBlock()); + + fegen::Function *function = + fegen::Function::get(functionName, functionParams, returnType); + manager.functionMap.insert(std::pair{functionName, function}); + sstack.popScope(); + return nullptr; + } + + std::any visitFuncName(FegenParser::FuncNameContext *ctx) override { + auto functionName = ctx->identifier()->getText(); + manager.addStmtContent(ctx, functionName); + return functionName; + } + + std::any visitFuncParams(FegenParser::FuncParamsContext *ctx) override { + std::vector paramsList = {}; + + for (size_t i = 0; i < ctx->typeSpec().size(); i++) { + auto paramType = + std::any_cast(this->visit(ctx->typeSpec(i))); + auto paramName = ctx->identifier(i)->getText(); + auto param = fegen::Value::get(paramType, paramName, + fegen::RightValue::getPlaceHolder()); + paramsList.push_back(param); + sstack.attemptAddVar(param); + } + manager.addStmtContent(ctx, paramsList); + return paramsList; + } + + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { + auto varType = std::any_cast(this->visit(ctx->typeSpec())); + manager.addStmtContent(ctx, varType); + auto varName = ctx->identifier()->getText(); + fegen::Value *var; + if (ctx->expression()) { + auto varcontent = + std::any_cast>( + this->visit(ctx->expression())); + // TODO: 支持获取expression的type后,可正常使用 + // if (!fegen::Type::isSameType(var->getType(), varcontent->getType())) { + // std::cerr << "The variabel \" " << varName << "\" need \"" + // << varType.getTypeName() + // << " \" type rightvalue. Now the expression is " + // << varcontent->exprType.getTypeName() << "." << std::endl; + // exit(0); + // return nullptr; + // } + var = fegen::Value::get(varType, varName, + fegen::RightValue::getByExpr(varcontent)); + } else { + var = fegen::Value::get(varType, varName, + fegen::RightValue::getPlaceHolder()); + } + sstack.attemptAddVar(var); + + return var; + } + + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { + auto varName = ctx->identifier()->getText(); + auto varcontent = + std::any_cast>( + this->visit(ctx->expression())); + auto var = sstack.attemptFindVar(varName); + + // TODO: 支持获取expression的type后,可正常使用 + // if (!fegen::Type::isSameType(var->getType(), varcontent->getType())) { + // std::cerr << "The variabel \" " << varName << "\" need \"" + // << var->getType()->toStringForCppKind() << " \" type + // rightvalue." + // << std::endl; + // exit(0); + // return nullptr; + // } + + fegen::Value *stmt = fegen::Value::get( + var->getType(), varName, fegen::RightValue::getByExpr(varcontent)); + manager.addStmtContent(ctx, stmt); + manager.addStmtContent(ctx->expression(), varcontent); + return var; + } + + // TODO:测试并补足函数调用 + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { + std::vector> parasList = {}; + auto functionName = + std::any_cast(this->visit(ctx->funcName())); + auto hasFunc = manager.functionMap.at(functionName); + auto paramsNum = ctx->expression().size(); + auto paraList = hasFunc->getInputTypeList(); + if (paramsNum > 0) { + for (size_t i = 0; i < paramsNum; i++) { + auto oprand = + std::any_cast>( + this->visit(ctx->expression(i))); + parasList.push_back(oprand); + } + size_t len1 = paraList.size(); + size_t len2 = parasList.size(); + if (len1 != len2) { + std::cerr << "The function \" " << functionName + << "\" parameter count mismatch." << std::endl; + exit(0); + return nullptr; + } + + // TODO: check parameter type + // for (size_t i = 0; i < len1; i++) { + // if (!fegen::Type::isSameType(paraList[i]->getType(), + // parasList[i]->exprType)) { + // std::cerr << "The function \" " << functionName << "\" parameter" + // << i + // << " type mismatch." << std::endl; + // exit(0); + // return nullptr; + // } + // } + } + auto returnType = hasFunc->getReturnType(); + fegen::Function *funcCall = + fegen::Function::get(functionName, paraList, returnType); + manager.stmtContentMap.insert(std::pair{ctx, funcCall}); + return returnType; + } + + // TODO:add op invoke + std::any visitOpInvokeStmt(FegenParser::OpInvokeStmtContext *ctx) override { + return nullptr; + } + + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { + for (size_t i = 0; i < ctx->ifBlock().size(); i++) { + this->visit(ctx->ifBlock(i)); + } + + if (ctx->elseBlock()) { + this->visit(ctx->elseBlock()); + } + return nullptr; + } + + std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { + sstack.pushScope(); + this->visit(ctx->expression()); + this->visit(ctx->statementBlock()); + sstack.popScope(); + + return nullptr; + } + + std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { + this->sstack.pushScope(); + this->visit(ctx->statementBlock()); + this->sstack.popScope(); + return nullptr; + } + + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { + sstack.pushScope(); + if (ctx->varDeclStmt()) { + this->visit(ctx->varDeclStmt()); + this->visit(ctx->expression()); + this->visit(ctx->assignStmt(0)); + } else { + this->visit(ctx->assignStmt(0)); + this->visit(ctx->expression()); + this->visit(ctx->assignStmt(1)); + } + this->visit(ctx->statementBlock()); + sstack.popScope(); + + return nullptr; + } + + std::any visitReturnBlock(FegenParser::ReturnBlockContext *ctx) override { + this->visit(ctx->expression()); + return nullptr; + } + + std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { + auto opName = ctx->opName()->getText(); + auto opDef = std::any_cast(this->visit(ctx->opBlock())); + opDef->setOpName(opName); + bool success = this->manager.addOperationDefination(opDef); + if (!success) { + // TODO: error report + std::cerr << "operation " << opName << " already exist." << std::endl; + } + return nullptr; + } + + // return FegenOperation* + std::any visitOpBlock(FegenParser::OpBlockContext *ctx) override { + std::vector args; + std::vector res; + if (ctx->argumentSpec()) { + args = std::any_cast>( + this->visit(ctx->argumentSpec())); + } + if (ctx->resultSpec()) { + res = std::any_cast>( + this->visit(ctx->resultSpec())); + } + return fegen::Operation::get("", args, res, ctx->bodySpec()); + } +}; +} // namespace fegen +#endif diff --git a/frontend/FrontendGen/include/Lexer.h b/frontend/FrontendGen/include/Lexer.h deleted file mode 100644 index 4ec1a88ea..000000000 --- a/frontend/FrontendGen/include/Lexer.h +++ /dev/null @@ -1,59 +0,0 @@ -//====- Lexer.h ---------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_LEXER_H -#define INCLUDE_LEXER_H -#include "Diagnostics.h" -#include "Token.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/Support/SourceMgr.h" -namespace frontendgen { - -/// Manage all keywords. -class KeyWordManager { - llvm::StringMap keywordMap; - void addKeyWords(); - -public: - KeyWordManager() { addKeyWords(); } - void addKeyWord(llvm::StringRef name, tokenKinds kind); - tokenKinds getKeyWord(llvm::StringRef name, tokenKinds kind); -}; - -class Lexer { - llvm::SourceMgr &srcMgr; - DiagnosticEngine &diagnostic; - const char *curPtr; - llvm::StringRef curBuffer; - KeyWordManager keywordManager; - -public: - Lexer(llvm::SourceMgr &srcMgr, DiagnosticEngine &diagnostic) - : srcMgr(srcMgr), diagnostic(diagnostic) { - curBuffer = srcMgr.getMemoryBuffer(srcMgr.getMainFileID())->getBuffer(); - curPtr = curBuffer.begin(); - } - DiagnosticEngine &getDiagnostic() { return diagnostic; } - void next(Token &token); - void identifier(Token &token); - void number(Token &token); - void formToken(Token &token, const char *tokenEnd, tokenKinds kind); - llvm::StringRef getMarkContent(std::string start, std::string end); - llvm::StringRef getEndChContent(const char *start, char ch); -}; - -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/Parser.h b/frontend/FrontendGen/include/Parser.h deleted file mode 100644 index 90ebb3a5b..000000000 --- a/frontend/FrontendGen/include/Parser.h +++ /dev/null @@ -1,60 +0,0 @@ -//====- Parser.h --------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_PARSER_H -#define INCLUDE_PARSER_H -#include "AST.h" -#include "Lexer.h" -#include "Sema.h" -#include "Terminator.h" -#include "Token.h" -namespace frontendgen { - -/// A class for parsing tokens. -class Parser { - Lexer &lexer; - Token token; - Sema &action; - Terminators &terminators; - -public: - Parser(Lexer &lexer, Sema &action, Terminators &terminators) - : lexer(lexer), action(action), terminators(terminators) { - advance(); - } - bool consume(tokenKinds kind); - bool consumeNoAdvance(tokenKinds kind); - void advance(); - Module *parser(); - void compilEngine(Module *module); - void parserRules(Rule *rule); - void parserGenerator(GeneratorAndOthers *generatorAndOthers); - void lookToken(); - AntlrBase::baseKind getAntlrBaseKind(llvm::StringRef name); - void parserIdentifier(GeneratorAndOthers *generatorAndOthers); - void parserTerminator(GeneratorAndOthers *generatorAndOthers); - void parserPBExpression(GeneratorAndOthers *generatorAndOthers); - void parserDialect(Dialect *&dialect, llvm::StringRef defName); - bool parserOp(std::vector &ops, llvm::StringRef opName); - void parserCurlyBracketOpen(GeneratorAndOthers *generatorAndOthers); - void parserDAG(DAG *&dag); - void parserBuilders(std::vector &builders); - void parserCode(llvm::StringRef &code); - void parserCArg(llvm::StringRef &operand, llvm::StringRef &value); -}; -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/Scope.h b/frontend/FrontendGen/include/Scope.h new file mode 100644 index 000000000..56e5eecb2 --- /dev/null +++ b/frontend/FrontendGen/include/Scope.h @@ -0,0 +1,69 @@ +#ifndef FEGEN_SCOPE_H +#define FEGEN_SCOPE_H + +#include "FegenManager.h" +#include + +namespace fegen { + +template class SymbolTable { +private: + std::map table; + +public: + SymbolTable() = default; + void add(std::string, T *e); + T *get(std::string name); + /// @brief return true if name exist in map. + bool exist(std::string name); + ~SymbolTable(); +}; + +class FegenScope { + using VariableTable = SymbolTable; + friend class ScopeStack; + +private: + unsigned int scopeId; + FegenScope *parentScope; + VariableTable varTable; + +public: + explicit FegenScope(unsigned int scopeId, FegenScope *parentScope); + ~FegenScope() = default; + + /// @brief this will not check. + Value *findVar(std::string name); + /// @brief this will not check whether var is already existed or not. + void addVar(Value *var); + /// @brief return true if exist. + bool isExistVar(std::string name); +}; + +class ScopeStack { +private: + std::vector scopes; + std::stack scopeStack; + + FegenScope *currentScope; + FegenScope *globalScope; + // scope total count + size_t count; + + ScopeStack(); + ~ScopeStack(); + ScopeStack(const ScopeStack &) = delete; + const ScopeStack &operator=(const ScopeStack &) = delete; + +public: + static ScopeStack &getScopeStack(); + void pushScope(); + void popScope(); + /// @brief check and add var to current scope, return false if failed. + bool attemptAddVar(Value *var); + /// @brief check add find var from current scope, return nullptr if failed. + Value *attemptFindVar(std::string name); +}; +} // namespace fegen + +#endif \ No newline at end of file diff --git a/frontend/FrontendGen/include/Sema.h b/frontend/FrontendGen/include/Sema.h deleted file mode 100644 index e9d40881c..000000000 --- a/frontend/FrontendGen/include/Sema.h +++ /dev/null @@ -1,35 +0,0 @@ -//====- Sema.h ----------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_SEMA_H -#define INCLUDE_SEMA_H -#include "AST.h" - -namespace frontendgen { - -class Sema { -public: - void actOnModule(Module *module, std::vector &rules, - Dialect *&dialect, std::vector &ops); - void actOnRule(Rule *rule, std::vector &generators); - void actOnDialect(Dialect *dialect, llvm::StringRef defName, - llvm::StringRef name, llvm::StringRef cppNamespace); - void actOnOps(std::vector &ops, llvm::StringRef opName, DAG *arguments, - DAG *results, std::vector &builder); - void actOnDag(DAG *&arguments, DAG &dag); -}; -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/Terminator.def b/frontend/FrontendGen/include/Terminator.def deleted file mode 100644 index 4423f42bd..000000000 --- a/frontend/FrontendGen/include/Terminator.def +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef terminator -#define terminator(NAME) -#endif - -terminator(Var, 'var') -terminator(Add, 'add') -terminator(Sub, 'sub') -terminator(Def, 'def') -terminator(Return, 'return') -terminator(ParentheseOpen, '(') -terminator(ParentheseClose, ')') -terminator(Comma, ',') -terminator(BracketOpen, '{') -terminator(BracketClose, '}') -terminator(SbracketOpen, '[') -terminator(SbracketClose, ']') -terminator(Semi, ';') -terminator(AngleBracketOpen, '<') -terminator(AngleBracketClose, '>') -terminator(Number, [0-9]+) -terminator(Equal, '=') -#undef terminator diff --git a/frontend/FrontendGen/include/Terminator.h b/frontend/FrontendGen/include/Terminator.h deleted file mode 100644 index 400ec45c3..000000000 --- a/frontend/FrontendGen/include/Terminator.h +++ /dev/null @@ -1,75 +0,0 @@ -//====- Terminator.h -------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_TERMINATOR_H -#define INCLUDE_TERMINATOR_H -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" - -namespace frontendgen { -class CGModule; -/// A class store antlr's terminators. -class Terminators { - friend class CGModule; - -private: - llvm::StringMap terminators; - llvm::SmallSet customTerminators; - -public: - Terminators() { -#define terminator(NAME, VALUE) terminators.insert(std::pair(#NAME, #VALUE)); -#include "Terminator.def" - } - /// Determine if it is a terminator. - bool isTerminator(llvm::StringRef terminator) { - std::string tmp = terminator.str(); - tmp[0] += 32; - if (customTerminators.contains(tmp)) - return true; - if (terminators.find(terminator) == terminators.end()) - return false; - return true; - } - - void addCustomTerminators(llvm::StringRef terminator) { - customTerminators.insert(terminator); - } - void addTerminator(llvm::StringRef terminator) { - terminators.insert(std::pair(terminator, terminator)); - } - /// Output all terminators. - void lookTerminators() { - llvm::outs() << "customTerminators\n"; - for (llvm::StringRef terminator : customTerminators) { - std::string terminatorName = terminator.str(); - terminatorName[0] -= 32; - llvm::outs() << "terminator name:" << terminatorName << ' ' - << "terminator content:" << terminator << '\n'; - } - for (auto start = terminators.begin(); start != terminators.end(); - ++start) { - llvm::outs() << "terminator name:" << start->first() << ' ' - << "terminator content:" << start->second << '\n'; - } - } -}; - -} // namespace frontendgen - -#endif diff --git a/frontend/FrontendGen/include/Token.def b/frontend/FrontendGen/include/Token.def deleted file mode 100644 index 253b5751b..000000000 --- a/frontend/FrontendGen/include/Token.def +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef TOK -#define TOK(ID) -#endif -#ifndef PUNCTUATOR -#define PUNCTUATOR(ID, SP) TOK(ID) -#endif -#ifndef KEYWORD -#define KEYWORD(ID, FLAG) TOK(kw_ ## ID) -#endif -TOK(unknown) -TOK(eof) -TOK(identifier) -TOK(number) -PUNCTUATOR(semi, ";") -PUNCTUATOR(colon, ":") -PUNCTUATOR(apostrophe, "'") -PUNCTUATOR(asterisk, "*") -PUNCTUATOR(parentheseOpen, "(") -PUNCTUATOR(parentheseClose, ")") -PUNCTUATOR(questionMark, "?") -PUNCTUATOR(plus, "+") -PUNCTUATOR(equal, "=") -PUNCTUATOR(curlyBlacketOpen, "{") -PUNCTUATOR(curlyBlacketClose, "}") -PUNCTUATOR(dollar, "$") -PUNCTUATOR(comma, ",") -PUNCTUATOR(angleBracketOpen, "<") -PUNCTUATOR(angleBracketClose, ">") -PUNCTUATOR(squareBracketOpen, "[") -PUNCTUATOR(squareBracketClose, "]") -PUNCTUATOR(doubleQuotationMark, "\"") -KEYWORD(rule, KEYALL) -KEYWORD(op, KEYALL) -KEYWORD(dialect, KEYALL) -#undef TOK -#undef PUNCTUATOR -#undef KEYWORD diff --git a/frontend/FrontendGen/include/Token.h b/frontend/FrontendGen/include/Token.h deleted file mode 100644 index 96f322753..000000000 --- a/frontend/FrontendGen/include/Token.h +++ /dev/null @@ -1,56 +0,0 @@ -//====- Token.h --------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#ifndef INCLUDE_TOKEN -#define INCLUDE_TOKEN -#include "Lexer.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/SMLoc.h" -#include "llvm/Support/raw_ostream.h" -namespace frontendgen { -enum tokenKinds { -#define TOK(ID) ID, -#include "Token.def" - NUM_TOKENS -}; -/// store token names. -static const char *tokenNameMap[] = { -#define TOK(ID) #ID, -#define KEYWORD(ID, FLAG) #ID, -#include "Token.def" - nullptr}; - -class Token { - friend class Lexer; - -private: - tokenKinds tokenKind; - const char *start; - int length; - -public: - void setTokenKind(tokenKinds kind) { tokenKind = kind; } - void setLength(int len) { length = len; } - - llvm::StringRef getContent() { return llvm::StringRef(start, length); } - tokenKinds getKind() { return tokenKind; } - const char *getTokenName() { return tokenNameMap[tokenKind]; } - bool is(tokenKinds kind); - llvm::SMLoc getLocation(); -}; - -} // namespace frontendgen -#endif diff --git a/frontend/FrontendGen/include/TypeMap.def b/frontend/FrontendGen/include/TypeMap.def deleted file mode 100644 index 5d7862b8d..000000000 --- a/frontend/FrontendGen/include/TypeMap.def +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef CPPMAP -#define CPPMAP(key, value) -#endif - -#ifndef ARGUMENTSMAP -#define ARGUMENTSMAP(key, value) -#endif - -#ifndef RESULTSMAP -#define RESULTSMAP(key, value) -#endif - -CPPMAP("\"StringRef\"", "llvm::StringRef") -CPPMAP("\"ArrayRef\"", "llvm::ArrayRef") -CPPMAP("\"FunctionType\"", "llvm::FunctionType") -CPPMAP("\"ArrayRef\"", "llvm::ArrayRef") -CPPMAP("\"Value\"", "mlir::Value") -CPPMAP("\"double\"", "double") -CPPMAP("\"DenseElementsAttr\"", "mlir::DenseElementsAttr") - -ARGUMENTSMAP("F64ElementsAttr", "mlir::Value") -ARGUMENTSMAP("F64Tensor", "mlir::Value") -ARGUMENTSMAP("Variadic", "mlir::Value") -ARGUMENTSMAP("SymbolNameAttr", "llvm::StringRef") -ARGUMENTSMAP("TypeAttrOf", "mlir::FunctionType") -ARGUMENTSMAP("F64MemRef", "mlir::Value") -RESULTSMAP("StaticShapeTensorOf<[F64]>", "mlir::Type") -RESULTSMAP("F64Tensor", "mlir::Type") - -#undef TYPEMAP -#undef ARGUMENTSMAP -#undef RESULTSMAP diff --git a/frontend/FrontendGen/lib/CGModule.cpp b/frontend/FrontendGen/lib/CGModule.cpp deleted file mode 100644 index 8d21ed652..000000000 --- a/frontend/FrontendGen/lib/CGModule.cpp +++ /dev/null @@ -1,422 +0,0 @@ -//====- CGModule.cpp ------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "CGModule.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include -#include - -using namespace frontendgen; - -/// Emit the ast,currently only antlr's ast are supported. -void CGModule::emitAST() { - for (auto i : module->getRules()) { - llvm::outs() << "rule name: " << i->getName() << '\n'; - for (auto j : i->getGeneratorsAndOthers()) { - llvm::outs() << " generator: " << '\n' << " "; - for (auto k : j->getGenerator()) { - if (k->getKind() == AntlrBase::baseKind::rule) - llvm::outs() << "\"" << k->getName() << "\"(rule) "; - else if (k->getKind() == AntlrBase::baseKind::terminator) - llvm::outs() << "\"" << k->getName() << "\"(terminator) "; - else if (k->getKind() == AntlrBase::baseKind::pbexpression) - llvm::outs() << "\"" << k->getName() << "\"(bpExpression) "; - } - llvm::outs() << '\n'; - } - } -} - -/// Emit the code of antlr , emit the generative formula first, then emit -/// user-defined terminator , and finally emit the system-defined terminator. -void CGModule::emitAntlr(llvm::StringRef grammarName) { - emitGrammar(grammarName); - emit(module->getRules()); - emitCustomTerminators(); - emitTerminators(); -} -/// Emit the system-defined terminator. -void CGModule::emitTerminators() { - for (auto start = terminators.terminators.begin(); - start != terminators.terminators.end(); start++) { - os << start->first() << '\n'; - os << " : " << start->second << "\n ;\n\n"; - } - emitWSAndComment(); -} - -void CGModule::emitGrammar(llvm::StringRef grammarName) { - os << "grammar " << grammarName << ";\n\n"; -} - -/// Emit user-defined terminator. -void CGModule::emitCustomTerminators() { - for (auto terminator : terminators.customTerminators) { - std::string tmp = terminator.str(); - if (tmp[0] >= 'a' && tmp[0] <= 'z') - tmp[0] -= 32; - llvm::StringRef name(tmp); - os << name << '\n'; - os << " : \'" << terminator.str() << "\'\n ;\n\n"; - } -} - -/// Emit the generative formula. -void CGModule::emit(const std::vector &rules) { - for (Rule *rule : rules) { - os << rule->getName() << '\n'; - emit(rule->getGeneratorsAndOthers()); - os << '\n'; - } -} - -/// Emit all generative formulas in a rule. -void CGModule::emit( - const std::vector &generatorsAndOthers) { - for (GeneratorAndOthers *generatorAndOthers : generatorsAndOthers) { - if (generatorAndOthers == generatorsAndOthers[0]) - os << " : "; - else - os << " | "; - emit(generatorAndOthers->getGenerator()); - } - os << " ;\n"; -} - -/// Output the elements of the generated formula. -void CGModule::emit(const std::vector &generator) { - for (AntlrBase *base : generator) { - if (base->getKind() == AntlrBase::baseKind::terminator) { - std::string tmp = base->getName().str(); - // The terminator in antlr must be capitalized. - if (tmp[0] >= 'a' && tmp[0] <= 'z') - tmp[0] -= 32; - llvm::StringRef name(tmp); - os << name << " "; - } else if (base->getKind() == AntlrBase::baseKind::rule) { - os << base->getName() << " "; - } else if (base->getKind() == AntlrBase::baseKind::pbexpression) { - os << base->getName(); - } - } - os << '\n'; -} - -/// TODO: Supports user-defined comment whitespace. -void CGModule::emitWSAndComment() { - os << "Identifier\n : [a-zA-Z][a-zA-Z0-9_]*\n ;\n\n"; - os << "WS\n : [ \\r\\n\\t] -> skip\n ;\n\n"; - os << "Comment\n : '#' .*? \'\\n\' ->skip\n ;\n"; -} - -void CGModule::emitMLIRVisitor(llvm::StringRef grammarName) { - emitIncludes(grammarName); - emitClass(grammarName); -} - -void CGModule::emitIncludes(llvm::StringRef grammarName) { - os << "#include \"" << grammarName << "BaseVisitor.h\"\n"; - os << "#include \"" << grammarName << "Lexer.h\"\n"; - os << "#include \"" << grammarName << "Parser.h\"\n"; - os << "#include \"mlir/IR/Attributes.h\"\n"; - os << "#include \"mlir/IR/Builders.h\"\n"; - os << "#include \"mlir/IR/BuiltinOps.h\"\n"; - os << "#include \"mlir/IR/BuiltinTypes.h\"\n"; - os << "#include \"mlir/IR/MLIRContext.h\"\n"; - os << "#include \"mlir/IR/Verifier.h\"\n"; - os << "#include \"llvm/ADT/STLExtras.h\"\n"; - os << "#include \"llvm/ADT/ScopedHashTable.h\"\n"; - os << "#include \"llvm/ADT/StringRef.h\"\n"; - os << "#include \"llvm/Support/raw_ostream.h\"\n"; - os << "\n"; -} - -/// Emit visitor class. -void CGModule::emitClass(llvm::StringRef grammarName) { - os << "class MLIR" << grammarName << "Visitor : public " << grammarName - << "BaseVisitor {\n"; - - os << "mlir::ModuleOp theModule;\n"; - os << "mlir::OpBuilder builder;\n"; - os << "std::string fileName;\n\n"; - - os << "public:\n"; - os << "MLIR" << grammarName - << "Visitor(std::string filename, mlir::MLIRContext &context)\n" - << ": builder(&context), fileName(filename) " - << "{\n theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); " - "\n}\n\n"; - os << "mlir::ModuleOp getModule() { return theModule; }\n\n"; - - // Emit all virtual functions. - auto rules = module->getRules(); - for (auto rule : rules) { - emitRuleVisitor(grammarName, rule); - } - os << "};\n"; -} -/// Emit virtual function in antlr. -void CGModule::emitRuleVisitor(llvm::StringRef grammarName, Rule *rule) { - std::string ruleName = rule->getName().str(); - ruleName[0] = ruleName[0] - 32; - os << "virtual std::any visit" << ruleName; - os << "(" << grammarName << "Parser::" << ruleName << "Context *ctx) {\n"; - emitBuilders(rule); - os << " return visitChildren(ctx);\n"; - os << "}\n\n"; -} - -void CGModule::emitBuilders(Rule *rule) { - for (GeneratorAndOthers *generatorAndOthers : - rule->getGeneratorsAndOthers()) { - llvm::SmallVector builderOpNames = - generatorAndOthers->getBuilderNames(); - llvm::SmallVector indices = generatorAndOthers->getBuilderIndices(); - int size = builderOpNames.size(); - for (int start = 0; start < size; start++) - emitBuilder(builderOpNames[start], indices[start]); - } -} - -void CGModule::emitBuilder(llvm::StringRef builderOp, int index) { - Op *op = findOp(builderOp); - if (op == nullptr) { - llvm::errs() << builderOp << " is undefined!\n"; - return; - } - emitOp(op, index); -} - -Op *CGModule::findOp(llvm::StringRef opName) { - for (Op *op : module->getOps()) { - if (op->getOpName() == opName) - return op; - } - return nullptr; -} - -/// Emit the operation we want to create. -void CGModule::emitOp(Op *op, int index) { - // Emit the default builder function. - if (index == 0) { - DAG *arguments = op->getArguments(); - DAG *result = op->getResults(); - llvm::SmallVector argOperands; - llvm::SmallVector argOperandNames; - llvm::SmallVector resOperands; - llvm::SmallVector resOperandNames; - if (arguments) { - argOperands = arguments->getOperands(); - argOperandNames = arguments->getOperandNames(); - } - if (result) { - resOperands = result->getOperands(); - resOperandNames = result->getOperandNames(); - } - os << " {\n"; - // opArguments are used to store the names of the arguments needed to create - // the operation. - llvm::SmallVector opArguments; - // tmpStrings are used to store and own some computed string, keeping their - // lifetime longger than the StringRefs in opArguments - llvm::SmallVector tmpStrings; - // Emit variables for creation operation. - // Emit variables of result type. - for (size_t index = 0; index < resOperands.size(); index++) { - if (!typeMap.findResultsMap(resOperands[index]).empty()) { - os << " " << typeMap.findResultsMap(resOperands[index]) << " "; - if (!resOperandNames[index].empty()) { - os << resOperandNames[index] << ";\n"; - opArguments.push_back(resOperandNames[index]); - } else { - tmpStrings.emplace_back("res" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else if (resOperands[index].startswith("AnyTypeOf")) { - llvm::StringRef operand = resOperands[index]; - auto start = operand.find('[') + 1; - auto end = operand.find(']'); - auto cur = start; - if (start == std::string::npos || end == std::string::npos) { - return; - } - llvm::StringRef type; - while (cur <= end) { - if (operand[cur] == ',' || cur == end) { - std::string str(operand, start, cur - start); - str.erase(0, str.find_first_not_of(" ")); - str.erase(str.find_last_not_of(" ") + 1); - if (typeMap.findResultsMap(str).empty()) { - llvm::errs() << str << " in " << op->getOpName() - << " in results is not supported.\n"; - } - type = typeMap.findResultsMap(str); - start = cur + 1; - } - cur++; - } - os << " " << type << " "; - if (!resOperandNames[index].empty()) { - os << resOperandNames[index] << ";\n"; - opArguments.push_back(resOperandNames[index]); - } else { - tmpStrings.emplace_back("res" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else { - llvm::errs() << resOperands[index] << " in " << op->getOpName() - << " in results is not supported.\n"; - return; - } - } - // Emit variables of argument type. - for (size_t index = 0; index < argOperands.size(); index++) { - if (!typeMap.findArgumentMap(argOperands[index]).empty()) { - os << " " << typeMap.findArgumentMap(argOperands[index]) << " "; - if (!argOperandNames[index].empty()) { - os << argOperandNames[index] << ";\n"; - opArguments.push_back(argOperandNames[index]); - } else { - tmpStrings.emplace_back("arg" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else if (argOperands[index].startswith("AnyTypeOf")) { - llvm::StringRef operand = argOperands[index]; - auto start = operand.find('[') + 1; - auto end = operand.find(']'); - auto cur = start; - if (start == std::string::npos || end == std::string::npos) { - return; - } - llvm::StringRef type; - while (cur <= end) { - if (operand[cur] == ',' || cur == end) { - std::string str(operand, start, cur - start); - str.erase(0, str.find_first_not_of(" ")); - str.erase(str.find_last_not_of(" ") + 1); - if (typeMap.findArgumentMap(str).empty()) { - llvm::errs() << str << " in " << op->getOpName() - << " in arguments is not supported.\n"; - } - start = cur + 1; - type = typeMap.findArgumentMap(str); - } - cur++; - } - os << " " << type << " "; - if (!argOperandNames[index].empty()) { - os << argOperandNames[index] << ";\n"; - opArguments.push_back(argOperandNames[index]); - } else { - tmpStrings.emplace_back("arg" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } else { - llvm::errs() << argOperands[index] << " in " << op->getOpName() - << " in arguments is not supported.\n"; - return; - } - } - // Emit the operation we want to create. - os << " mlir::Location location;\n"; - llvm::StringRef cppNameSpace( - module->getDialect()->getCppNamespace().data() + 1, - module->getDialect()->getCppNamespace().size() - 2); - os << " " - << "builder.create<" << cppNameSpace << "::" << op->getOpName() - << ">(location"; - if (opArguments.size()) - os << ", "; - for (size_t index = 0; index < opArguments.size(); index++) { - os << opArguments[index]; - if (index + 1 != opArguments.size()) - os << ", "; - } - os << ");\n"; - os << " }\n\n"; - } else if (index > 0) { - // Emit custom builder function. - index--; - // Emit the variables which are used to fill builder function. - llvm::SmallVector operands = - op->getBuilders()[index]->getDag()->getOperands(); - llvm::SmallVector operandNames = - op->getBuilders()[index]->getDag()->getOperandNames(); - llvm::SmallVector opArguments; - llvm::SmallVector tmpStrings; - os << " {\n"; - for (size_t index = 0; index < operands.size(); index++) { - if (!typeMap.findCppMap(operands[index]).empty()) - os << " " << typeMap.findCppMap(operands[index]); - else - os << " " << operands[index]; - if (!operandNames[index].empty()) { - os << " " << operandNames[index] << ";\n"; - opArguments.push_back(operandNames[index]); - } else { - tmpStrings.emplace_back("arg" + std::to_string(index)); - const auto &arg = tmpStrings.back(); - os << arg << ";\n"; - opArguments.push_back(arg); - } - } - // Emit the operation we want to create. - os << " mlir::Location location;\n"; - llvm::StringRef cppNameSpace( - module->getDialect()->getCppNamespace().data() + 1, - module->getDialect()->getCppNamespace().size() - 2); - os << " " - << "builder.create<" << cppNameSpace << "::" << op->getOpName() - << ">(location"; - if (!operandNames.empty()) { - os << ", "; - for (size_t index = 0; index < opArguments.size(); index++) { - os << opArguments[index]; - if (index + 1 != opArguments.size()) - os << ", "; - } - } - os << ");\n"; - os << " }\n\n"; - } -} - -llvm::StringRef TypeMap::findCppMap(llvm::StringRef key) { - if (cppMap.find(key) == cppMap.end()) - return llvm::StringRef(); - return cppMap[key]; -} - -llvm::StringRef TypeMap::findArgumentMap(llvm::StringRef key) { - if (argumentsMap.find(key) == argumentsMap.end()) - return llvm::StringRef(); - return argumentsMap[key]; -} - -llvm::StringRef TypeMap::findResultsMap(llvm::StringRef key) { - if (resultsMap.find(key) == resultsMap.end()) - return llvm::StringRef(); - return resultsMap[key]; -} diff --git a/frontend/FrontendGen/lib/CMakeLists.txt b/frontend/FrontendGen/lib/CMakeLists.txt index 90e9d4502..890795123 100644 --- a/frontend/FrontendGen/lib/CMakeLists.txt +++ b/frontend/FrontendGen/lib/CMakeLists.txt @@ -1,12 +1,50 @@ -include_directories(../include) -set(LLVM_LINK_COMPONENTS -support) - -add_llvm_component_library(LLVMfrontendgenlib -CGModule.cpp -Lexer.cpp -Parser.cpp -Sema.cpp -Diagnostics.cpp -LINK_COMPONENTS -support) +antlr_target(FegenLexer FegenLexer.g4 + PACKAGE fegen + LEXER + ) + +antlr_target(FegenParser FegenParser.g4 + PACKAGE fegen + DEPENDS_ANTLR FegenLexer + PARSER + LISTENER + VISITOR + COMPILE_FLAGS -lib + ${ANTLR_FegenLexer_OUTPUT_DIR} + ) + +include_directories(${ANTLR_FegenLexer_OUTPUT_DIR}) +set(ANTLR_FegenLexer_OUTPUT_DIR ${ANTLR_FegenLexer_OUTPUT_DIR} CACHE STRING "ANTLR_FegenLexer_OUTPUT_DIR") +include_directories(${ANTLR_FegenParser_OUTPUT_DIR}) +set(ANTLR_FegenParser_OUTPUT_DIR ${ANTLR_FegenParser_OUTPUT_DIR} CACHE STRING "ANTLR_FegenParser_OUTPUT_DIR") + +add_library(fegen_antlr_generated + ${ANTLR_FegenLexer_CXX_OUTPUTS} + ${ANTLR_FegenParser_CXX_OUTPUTS} +) +add_dependencies(fegen_antlr_generated antlr4_runtime) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../include") +add_library(FegenSupport + FegenManager.cpp + Scope.cpp +) +add_dependencies(FegenSupport fegen_antlr_generated) + +llvm_map_components_to_libnames(llvm_libs support) + +target_link_libraries(FegenSupport + PRIVATE + ${llvm_libs} +) + +add_library(fegenVisitor + FegenVisitor.cpp +) + +target_link_libraries(fegenVisitor + PUBLIC + fegen_antlr_generated + antlr4_static + FegenSupport +) \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Diagnostics.cpp b/frontend/FrontendGen/lib/Diagnostics.cpp deleted file mode 100644 index d75986709..000000000 --- a/frontend/FrontendGen/lib/Diagnostics.cpp +++ /dev/null @@ -1,43 +0,0 @@ -//====- Diagnostics.cpp -------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Diagnostics.h" -#include "llvm/Support/SourceMgr.h" - -using namespace frontendgen; -namespace { - -/// Storage the message of the diagnostic. -const char *diagnosticText[] = { -#define DIAG(ID, Level, Msg) Msg, -#include "Diagnostics.def" -}; - -/// Storage the kind of the diagnostic. -llvm::SourceMgr::DiagKind diagnosticKind[] = { -#define DIAG(ID, Level, Msg) llvm::SourceMgr::DK_##Level, -#include "Diagnostics.def" -}; -} // namespace - -/// Get the message of the diagnostic. -const char *DiagnosticEngine::getDiagnosticText(unsigned diagID) { - return diagnosticText[diagID]; -} -/// Get the kind of the diagnostic. -llvm::SourceMgr::DiagKind DiagnosticEngine::getDiagnosticKind(unsigned DiagID) { - return diagnosticKind[DiagID]; -} diff --git a/frontend/FrontendGen/lib/FegenLexer.g4 b/frontend/FrontendGen/lib/FegenLexer.g4 new file mode 100644 index 000000000..6b5e0aa2a --- /dev/null +++ b/frontend/FrontendGen/lib/FegenLexer.g4 @@ -0,0 +1,223 @@ +lexer grammar FegenLexer; + +fragment Schar: ~ ["\\\r\n]; + +fragment NONDIGIT: [a-zA-Z_]; + +fragment UPPERCASE: [A-Z]; + +fragment LOWERCASE: [a-z]; + +fragment ALLCASE: [a-zA-Z0-9_]; + +fragment NOZERODIGIT: [1-9]; + +fragment DIGIT: [0-9]; + +fragment SQuoteLiteral + : '\'' (('\\' ([btnfr"'\\] | . |EOF))|( ~ ['\r\n\\]))* '\'' + ; + +// literal + +UnsignedInt: NOZERODIGIT DIGIT* | '0'; + +ScienceReal : (Plus | Minus)? UnsignedInt Dot UnsignedInt ( 'E' (Plus | Minus)? UnsignedInt )?; + +ConstBoolean: 'true' | 'false'; + +// key words + +FEGEN: 'fegen'; + +INPUTS: 'inputs'; + +RETURNS: 'returns'; + +ACTIONS: 'actions'; + +IR: 'ir'; + +OPERAND_VALUE: 'operandValue'; + +ATTRIBUTE_VALUE: 'attributeValue'; + +CPP_VALUE: 'cppValue'; + +OPERATION: 'operation'; + +FUNCTION: 'function'; + +TYPEDEF: 'typedef'; + +OPDEF: 'opdef'; + +ARGUMENTS: 'arguments'; + +RESULTS: 'results'; + +BODY: 'body'; + +EMPTY: 'null'; + +PARAMETERS: 'parameters'; + +ASSEMBLY_FORMAT: 'assemblyFormat'; + + +// types +TYPE: 'Type'; + +TYPETEMPLATE: 'TypeTemplate'; + +BOOL: 'bool'; + +INT: 'int'; + +FLOAT: 'float'; + +DOUBLE: 'double'; + +// F64TENSOR: 'F64Tensor'; + +// F64VECTOR: 'F64Vector'; + +CHAR: 'char'; + +STRING: 'string'; + +LIST: 'list'; + +ANY: 'any'; + +OPTIONAL: 'optional'; + +INTEGER: 'Integer'; + +FLOATPOINT: 'FloatPoint'; + +TENSOR: 'Tensor'; + +VECTOR: 'Vector'; + +CPP: 'cpp'; + +OPERAND: 'operand'; + +ATTRIBUTE: 'attribute'; + +// stmt + +IF: 'if'; + +ELSE: 'else'; + +FOR: 'for'; + +IN: 'in'; + +WHILE: 'while'; + +RETURN: 'return'; + +// identifiers + +LexerRuleName: UPPERCASE (NONDIGIT | DIGIT)*; + +ParserRuleName: LOWERCASE (NONDIGIT | DIGIT)*; + +// literal + +StringLiteral + : SQuoteLiteral + ; + + +// marks + +AND: '&&'; + +Logic_OR: '||'; + +EQUAL: '=='; + +NOT_EQUAL: '!='; + +Less: '<'; + +LessEqual: '<='; + +Greater: '>'; + +GreaterEqual: '>='; + +Comma: ','; + +Semi: ';'; + +LeftParen: '('; + +RightParen: ')'; + +LeftBracket: '['; + +RightBracket: ']'; + +LeftBrace: '{'; + +RightBrace: '}'; + +Dot: '.'; + +Colon: ':'; + +OR: '|'; + +QuestionMark: '?'; + +Star: '*'; + +Div: '/'; + +Plus: '+'; + +Minus: '-'; + +Assign: '='; + +Dollar: '$'; + +StarStar: '**'; + +MOD: '%'; + +Arror: '->'; + +Underline: '_'; + +Tilde: '~'; + +Exclamation: '!'; + +Range: '..'; + +BeginInclude: '@header' LeftBrace -> pushMode (TargetLanguageAction); + +Whitespace: [ \t]+ -> skip; + +Newline: ('\r' '\n'? | '\n') -> skip; + +BlockComment: '/*' .*? '*/' -> skip; + +LineComment: '//' ~ [\r\n]* -> skip; + +mode TargetLanguageAction; + +EndInclude: RightBrace -> popMode; + +INCLUDE_CONTENT + : . + | '\n' + | ' ' + ; + diff --git a/frontend/FrontendGen/lib/FegenManager.cpp b/frontend/FrontendGen/lib/FegenManager.cpp new file mode 100644 index 000000000..04ede3170 --- /dev/null +++ b/frontend/FrontendGen/lib/FegenManager.cpp @@ -0,0 +1,2121 @@ +#include "FegenManager.h" +#include "FegenParser.h" +#include "FegenParserBaseVisitor.h" +#include "Scope.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +fegen::Function::Function(std::string name, + std::vector &&inputTypeList, + TypePtr returnType) + : name(name), inputTypeList(inputTypeList), returnType(returnType) {} + +fegen::Function *fegen::Function::get(std::string name, + std::vector inputTypeList, + TypePtr returnType) { + return new fegen::Function(name, std::move(inputTypeList), returnType); +} +std::string fegen::Function::getName() { return this->name; } + +std::vector &fegen::Function::getInputTypeList() { + return this->inputTypeList; +} + +fegen::Value *fegen::Function::getInputTypeList(size_t i) { + return this->inputTypeList[i]; +} + +fegen::TypePtr fegen::Function::getReturnType() { return this->returnType; } + +fegen::Operation::Operation(std::string dialectName, std::string operationName, + std::vector &&arguments, + std::vector &&results, + fegen::FegenParser::BodySpecContext *ctx) + : dialectName(dialectName), arguments(arguments), results(results), + ctx(ctx) {} + +void fegen::Operation::setOpName(std::string name) { + this->operationName = name; +} +std::string fegen::Operation::getOpName() { return this->operationName; } + +std::vector &fegen::Operation::getArguments() { + return this->arguments; +} + +fegen::Value *fegen::Operation::getArguments(size_t i) { + return this->arguments[i]; +} + +std::vector &fegen::Operation::getResults() { + return this->results; +} + +fegen::Value *fegen::Operation::getResults(size_t i) { + return this->results[i]; +} + +fegen::Operation *fegen::Operation::get(std::string operationName, + std::vector arguments, + std::vector results, + FegenParser::BodySpecContext *ctx) { + return new fegen::Operation(fegen::Manager::getManager().moduleName, + operationName, std::move(arguments), + std::move(results), ctx); +} + +// class FegenType + +fegen::Type::Type(TypeKind kind, std::string name, TypeDefination *tyDef, + int typeLevel, bool isConstType) + : kind(kind), typeName(name), typeDefine(tyDef), typeLevel(typeLevel), + isConstType(isConstType) {} + +fegen::Type::TypeKind fegen::Type::getTypeKind() { return this->kind; } + +void fegen::Type::setTypeKind(fegen::Type::TypeKind kind) { this->kind = kind; } + +fegen::TypeDefination *fegen::Type::getTypeDefination() { + return this->typeDefine; +} + +void fegen::Type::setTypeDefination(fegen::TypeDefination *tyDef) { + this->typeDefine = tyDef; +} + +std::string fegen::Type::getTypeName() { return this->typeName; } + +int fegen::Type::getTypeLevel() { return this->typeLevel; } + +bool fegen::Type::isConstant() { return this->isConstType; } + +bool fegen::Type::isSameType(fegen::TypePtr type1, fegen::TypePtr type2) { + if (type1->getTypeName() == type2->getTypeName()) { + std::cout << "1" << std::endl; + return true; + } + + else + return false; +} + +std::string fegen::Type::toStringForTypedef() { + std::cerr << this->getTypeName() << std::endl; + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::Type::toStringForOpdef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::Type::toStringForCppKind() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +fegen::TypePtr fegen::Type::getPlaceHolder() { + return std::make_shared(); +} + +fegen::TypePtr fegen::Type::getMetaType() { + return std::make_shared(); +} + +fegen::TypePtr fegen::Type::getMetaTemplateType() { + return std::make_shared(); +} + +fegen::TypePtr fegen::Type::getInt32Type() { + return std::make_shared(RightValue::getInteger(32)); +} + +fegen::TypePtr fegen::Type::getFloatType() { + return std::make_shared(RightValue::getInteger(32)); +} + +fegen::TypePtr fegen::Type::getDoubleType() { + return std::make_shared(RightValue::getInteger(64)); +} + +fegen::TypePtr fegen::Type::getBoolType() { + return std::make_shared(RightValue::getInteger(1)); +} + +fegen::TypePtr fegen::Type::getIntegerType(fegen::RightValue size) { + return std::make_shared(size); +} + +fegen::TypePtr fegen::Type::getFloatPointType(fegen::RightValue size) { + return std::make_shared(size); +} + +fegen::TypePtr fegen::Type::getStringType() { + return std::make_shared(); +} + +fegen::TypePtr fegen::Type::getListType(fegen::TypePtr elementType) { + assert(elementType->typeLevel == 3); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getListType(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getListType(ty); +} + +fegen::TypePtr fegen::Type::getVectorType(fegen::TypePtr elementType, + fegen::RightValue size) { + assert(elementType->typeLevel == 3); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType), size); +} + +fegen::TypePtr fegen::Type::getVectorType(RightValue elementType, + RightValue size) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getVectorType(ty, size); +} + +fegen::TypePtr fegen::Type::getTensorType(fegen::TypePtr elementType) { + assert(elementType->typeLevel == 3); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getTensorType(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getTensorType(ty); +} + +fegen::TypePtr fegen::Type::getOptionalType(fegen::TypePtr elementType) { + assert(elementType->typeLevel == 3); + return std::make_shared( + RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getOptionalType(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getOptionalType(ty); +} + +fegen::TypePtr fegen::Type::getAnyType(fegen::RightValue elementTypes) { + return std::make_shared(elementTypes); +} + +fegen::TypePtr +fegen::Type::getCustomeType(std::vector params, + fegen::TypeDefination *tydef) { + return std::make_shared(params, tydef); +} + +// Integer +fegen::TypePtr fegen::Type::getIntegerTemplate() { + return std::make_shared(); +} + +// FloatPoint +fegen::TypePtr fegen::Type::getFloatPointTemplate() { + return std::make_shared(); +} + +// string +fegen::TypePtr fegen::Type::getStringTemplate() { + return std::make_shared(); +} + +// List +fegen::TypePtr fegen::Type::getListTemplate(TypePtr elementType) { + assert(elementType->typeLevel == 2 || elementType->typeLevel == 1); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} + +fegen::TypePtr fegen::Type::getListTemplate(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getListTemplate(ty); +} + +// Vector +fegen::TypePtr fegen::Type::getVectorTemplate() { + return std::make_shared(); +} + +// Tensor +fegen::TypePtr fegen::Type::getTensorTemplate() { + return std::make_shared(); +} + +// Optional +fegen::TypePtr fegen::Type::getOptionalTemplate(TypePtr elementType) { + assert(elementType->typeLevel == 2); + return std::make_shared( + fegen::RightValue::getTypeRightValue(elementType)); +} +fegen::TypePtr fegen::Type::getOptionalTemplate(RightValue elementType) { + auto ty = std::any_cast(elementType.getContent()); + return Type::getOptionalTemplate(ty); +} + +// Any<[elementType1, elementType2, ...]> +fegen::TypePtr fegen::Type::getAnyTemplate(RightValue elementTypes) { + return std::make_shared(elementTypes); +} + +fegen::TypePtr fegen::Type::getCustomeTemplate(TypeDefination *tydef) { + assert(tydef->isCustome()); + return std::make_shared(tydef); +} + +/// @brief get name of Type Instance by jointsing template name and parameters, +/// for example: Integer + 32 --> Integer<32> +/// @return joint name +std::string jointTypeName(std::string templateName, + std::vector parameters) { + if (parameters.empty()) { + return templateName; + } + std::string res = templateName; + res.append("<"); + size_t count = parameters.size(); + auto firstParamStr = parameters[0].toString(); + res.append(firstParamStr); + for (size_t i = 1; i <= count - 1; i++) { + auto paramStr = parameters[i].toString(); + res.append(", "); + res.append(paramStr); + } + res.append(">"); + return res; +} + +// class PlaceHolderType +fegen::PlaceHolderType::PlaceHolderType() + : Type(fegen::Type::TypeKind::CPP, FEGEN_PLACEHOLDER, + fegen::Manager::getManager().getTypeDefination(FEGEN_PLACEHOLDER), 0, + true) {} + +// class MetaType +fegen::MetaType::MetaType() + : Type(fegen::Type::TypeKind::CPP, FEGEN_TYPE, + fegen::Manager::getManager().getTypeDefination(FEGEN_TYPE), 2, + true) {} + +std::string fegen::MetaType::toStringForTypedef() { return "\"Type\""; } + +// class MetaTemplate +fegen::MetaTemplate::MetaTemplate() + : Type(fegen::Type::TypeKind::CPP, FEGEN_TYPETEMPLATE, + fegen::Manager::getManager().getTypeDefination(FEGEN_TYPETEMPLATE), + 1, true) {} + +// class IntegerType + +fegen::IntegerType::IntegerType(RightValue size, TypeDefination *tyDef) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_INTEGER, {size}), + tyDef, 3, size.isConstant()), + size(size) {} + +fegen::IntegerType::IntegerType(fegen::RightValue size) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_INTEGER, {size}), + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER), 3, + size.isConstant()), + size(size) {} + +fegen::largestInt fegen::IntegerType::getSize() { + assert(this->size.getLiteralKind() == RightValue::LiteralKind::INT); + return std::any_cast(this->size.getContent()); +} + +std::string fegen::IntegerType::toStringForTypedef() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "\"int\""; + } else if (content == 1) { + return "\"bool\""; + } else if (content == 64) { + return "\"long\""; + } else if (content == 16) { + return "\"short\""; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +std::string fegen::IntegerType::toStringForOpdef() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "I32"; + } else if (content == 64) { + return "I64"; + } else if (content == 16) { + return "I16"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +std::string fegen::IntegerType::toStringForCppKind() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "int"; + } + if (content == 64) { + return "long"; + } else if (content == 16) { + return "short"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +// class FloatPointType +fegen::FloatPointType::FloatPointType(fegen::RightValue size) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_FLOATPOINT, {size}), + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT), 3, + size.isConstant()), + size(size) {} + +fegen::largestInt fegen::FloatPointType::getSize() { + assert(this->size.getLiteralKind() == RightValue::LiteralKind::INT); + return std::any_cast(this->size.getContent()); +} + +std::string fegen::FloatPointType::toStringForTypedef() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "\"float\""; + } else if (content == 64) { + return "\"double\""; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +std::string fegen::FloatPointType::toStringForOpdef() { + auto content = std::any_cast(this->size.getContent()); + switch (this->getTypeKind()) { + case Type::TypeKind::ATTRIBUTE: { + if (content == 32) { + return "F32ElementsAttr"; + } else if (content == 64) { + return "F64ElementsAttr"; + } + break; + } + case Type::TypeKind::OPERAND: { + if (content == 32) { + return "F32"; + } else if (content == 64) { + return "F64"; + } + break; + } + default: { + assert(false); + } + } + assert(false); +} + +std::string fegen::FloatPointType::toStringForCppKind() { + auto content = std::any_cast(this->size.getContent()); + if (content == 32) { + return "float"; + } else if (content == 64) { + return "double"; + } else if (content == 128) { + return "long double"; + } else { + std::cerr << "unsupport type: " << this->getTypeName() << std::endl; + assert(false); + } +} + +// class StringType +fegen::StringType::StringType() + : Type(fegen::Type::TypeKind::CPP, FEGEN_STRING, + fegen::Manager::getManager().getTypeDefination(FEGEN_STRING), 3, + true) {} + +std::string fegen::StringType::toStringForCppKind() { return "string"; } + +// class ListType +fegen::ListType::ListType(fegen::RightValue elementType) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_LIST, {elementType}), + fegen::Manager::getManager().getTypeDefination(FEGEN_LIST), 3, + elementType.isConstant()), + elementType(elementType) {} + +std::string fegen::ListType::toStringForTypedef() { + std::string res = "ArrayRefParameter<"; + res.append(this->elementType.toStringForTypedef()); + res.append(">"); + return res; +} + +std::string fegen::ListType::toStringForOpdef() { + std::string res = "Variadic<"; + res.append(this->elementType.toStringForOpdef()); + res.append(">"); + return res; +} + +std::string fegen::ListType::toStringForCppKind() { + std::string res = "std::vector<"; + res.append(this->elementType.toStringForCppKind()); + res.append(">"); + return res; +} + +// class VectorType +fegen::VectorType::VectorType(RightValue elementType, RightValue size) + : Type(fegen::Type::TypeKind::CPP, + jointTypeName(FEGEN_VECTOR, {elementType, size}), + fegen::Manager::getManager().getTypeDefination(FEGEN_VECTOR), 3, + (elementType.isConstant() && size.isConstant())), + elementType(elementType), size(size) {} + +// class TensorType +fegen::TensorType::TensorType(RightValue elementType) + : Type(fegen::Type::TypeKind::CPP, + jointTypeName(FEGEN_TENSOR, {elementType}), + fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR), 3, + elementType.isConstant()), + elementType(elementType) {} + +std::string fegen::TensorType::toStringForOpdef() { + auto elemTy = std::any_cast(this->elementType.getContent()); + auto elemTyName = elemTy->getTypeDefination()->getName(); + if (elemTyName != FEGEN_INTEGER && elemTyName != FEGEN_FLOATPOINT) { + assert(false); + } + if (elemTyName == FEGEN_INTEGER) { + auto intTy = std::dynamic_pointer_cast(elemTy); + auto size = intTy->getSize(); + switch (size) { + case 1: + return "I1Tensor"; + case 8: + return "I8Tensor"; + case 16: + return "I16Tensor"; + case 32: + return "I32Tensor"; + case 64: + return "I64Tensor"; + default: { + std::cerr << "unsupprot type: " << this->getTypeName() << std::endl; + exit(0); + } + } + } else { + auto floatTy = std::dynamic_pointer_cast(elemTy); + auto size = floatTy->getSize(); + switch (size) { + case 16: + return "F16Tensor"; + case 32: + return "F32Tensor"; + case 64: + return "F64Tensor"; + default: { + std::cerr << "unsupprot type: " << this->getTypeName() << std::endl; + exit(0); + } + } + } +} + +// class OptionalType +fegen::OptionalType::OptionalType(RightValue elementType) + : Type(fegen::Type::TypeKind::CPP, + jointTypeName(FEGEN_OPTINAL, {elementType}), + fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL), 3, + elementType.isConstant()), + elementType(elementType) {} + +// class AnyType + +inline int getTypeLevelOfListType(fegen::RightValue &elementTypes) { + auto listContent = std::any_cast>( + elementTypes.getContent()); + fegen::TypePtr ty = + std::any_cast(listContent[0]->getContent()); + return ty->getTypeLevel(); +} + +fegen::AnyType::AnyType(RightValue elementTypes) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_ANY, {elementTypes}), + fegen::Manager::getManager().getTypeDefination(FEGEN_ANY), 3, + elementTypes.isConstant()), + elementTypes(elementTypes) {} + +// class CustomeType +inline bool isAllConstant(std::vector ¶ms) { + for (auto v : params) { + if (!v.isConstant()) { + return false; + } + } + return true; +} + +fegen::CustomeType::CustomeType(std::vector params, + TypeDefination *tydef) + : Type(fegen::Type::TypeKind::CPP, jointTypeName(FEGEN_ANY, params), tydef, + 3, isAllConstant(params)), + params(params) {} + +// class TemplateType +fegen::TemplateType::TemplateType(TypeDefination *tydef) + : Type(fegen::Type::TypeKind::CPP, tydef->getName(), tydef, 2, true) {} + +// class IntegerTemplateType +fegen::IntegerTemplateType::IntegerTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_INTEGER)) {} + +fegen::TypePtr +fegen::IntegerTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getIntegerType(params[0]); +} + +std::string fegen::IntegerTemplateType::toStringForTypedef() { + return "Builtin_IntegerAttr"; +} + +std::string fegen::IntegerTemplateType::toStringForOpdef() { + return "Builtin_Integer"; +} + +// class FloatPointTemplateType +fegen::FloatPointTemplateType::FloatPointTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_FLOATPOINT)) {} + +fegen::TypePtr +fegen::FloatPointTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getFloatPointType(params[0]); +} + +std::string fegen::FloatPointTemplateType::toStringForTypedef() { + return "Builtin_FloatAttr"; +} + +// class StringTemplateType +fegen::StringTemplateType::StringTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_STRING)) {} + +fegen::TypePtr +fegen::StringTemplateType::instantiate(std::vector params) { + assert(params.size() == 0); + return Type::getStringType(); +} + +std::string fegen::StringTemplateType::toStringForTypedef() { + return "Builtin_StringAttr"; +} + +// class ListTemplateType +fegen::ListTemplateType::ListTemplateType(fegen::RightValue elementType) + : TemplateType(fegen::Manager::getManager().getTypeDefination(FEGEN_LIST)), + elementType(elementType) {} + +fegen::TypePtr +fegen::ListTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getListType(params[0]); +} + +std::string fegen::ListTemplateType::toStringForTypedef() { + std::string res = "ArrayRefParameter<"; + res.append(this->elementType.toStringForTypedef()); + res.append(">"); + return res; +} + +std::string fegen::ListTemplateType::toStringForOpdef() { + std::string res = "Variadic<"; + res.append(this->elementType.toStringForOpdef()); + res.append(">"); + return res; +} + +// class VectorTemplateType +fegen::VectorTemplateType::VectorTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_VECTOR)) {} + +fegen::TypePtr +fegen::VectorTemplateType::instantiate(std::vector params) { + assert(params.size() == 2); + return Type::getVectorType(params[0], params[1]); +} + +// class TensorTemplateType +fegen::TensorTemplateType::TensorTemplateType() + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_TENSOR)) {} + +fegen::TypePtr +fegen::TensorTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getTensorType(params[0]); +} + +// class OptionalTemplateType +fegen::OptionalTemplateType::OptionalTemplateType(RightValue elementType) + : TemplateType( + fegen::Manager::getManager().getTypeDefination(FEGEN_OPTINAL)), + elementType(elementType) {} + +fegen::TypePtr +fegen::OptionalTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getOptionalType(params[0]); +} + +// class AnyTemplateType +fegen::AnyTemplateType::AnyTemplateType(RightValue elementTypes) + : TemplateType(fegen::Manager::getManager().getTypeDefination(FEGEN_ANY)), + elementTypes(elementTypes) {} + +fegen::TypePtr +fegen::AnyTemplateType::instantiate(std::vector params) { + assert(params.size() == 1); + return Type::getAnyType(params[0]); +} + +// class CustomeTemplateType +fegen::CustomeTemplateType::CustomeTemplateType(TypeDefination *tydef) + : TemplateType(tydef) {} + +fegen::TypePtr +fegen::CustomeTemplateType::instantiate(std::vector params) { + return Type::getCustomeType(params, this->getTypeDefination()); +} + +std::string fegen::CustomeTemplateType::toStringForTypedef() { + return this->getTypeDefination()->getName(); +} + +std::string fegen::CustomeTemplateType::toStringForOpdef() { + return this->getTypeDefination()->getName(); +} + +// class FegenTypeDefination +fegen::TypeDefination::TypeDefination( + std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, bool ifCustome) + : dialectName(std::move(dialectName)), name(std::move(name)), + parameters(std::move(parameters)), ctx(ctx), ifCustome(ifCustome) {} + +fegen::TypeDefination * +fegen::TypeDefination::get(std::string dialectName, std::string name, + std::vector parameters, + FegenParser::TypeDefinationDeclContext *ctx, + bool ifCustome) { + return new fegen::TypeDefination(std::move(dialectName), std::move(name), + std::move(parameters), ctx, ifCustome); +} + +std::string fegen::TypeDefination::getDialectName() { + return this->dialectName; +} + +void fegen::TypeDefination::setDialectName(std::string name) { + this->dialectName = name; +} + +std::string fegen::TypeDefination::getName() { return this->name; } + +std::string fegen::TypeDefination::getMnemonic() { + if (this->mnemonic.empty()) { + this->mnemonic = this->name; + std::transform(this->mnemonic.begin(), this->mnemonic.end(), + this->mnemonic.begin(), ::tolower); + } + return this->mnemonic; +} + +void fegen::TypeDefination::setName(std::string name) { this->name = name; } + +const std::vector &fegen::TypeDefination::getParameters() { + return this->parameters; +} + +fegen::FegenParser::TypeDefinationDeclContext *fegen::TypeDefination::getCtx() { + return this->ctx; +} + +void fegen::TypeDefination::setCtx( + FegenParser::TypeDefinationDeclContext *ctx) { + this->ctx = ctx; +} + +bool fegen::TypeDefination::isCustome() { return this->ifCustome; } + +// class Expression + +fegen::RightValue::Expression::Expression(bool ifTerminal, LiteralKind kind, + bool isConstexpr) + : ifTerminal(ifTerminal), kind(kind), ifConstexpr(isConstexpr) {} + +bool fegen::RightValue::Expression::isTerminal() { return this->ifTerminal; } + +fegen::RightValue::LiteralKind fegen::RightValue::Expression::getKind() { + return this->kind; +} + +bool fegen::RightValue::Expression::isConstexpr() { return this->ifConstexpr; } + +std::shared_ptr +fegen::RightValue::Expression::getPlaceHolder() { + return std::make_shared(); +} + +std::shared_ptr +fegen::RightValue::Expression::getInteger(largestInt content, size_t size) { + return std::make_shared(content, size); +} + +std::shared_ptr +fegen::RightValue::Expression::getFloatPoint(long double content, size_t size) { + return std::make_shared(content, size); +} + +std::shared_ptr +fegen::RightValue::Expression::getString(std::string content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::RightValue::Expression::getTypeRightValue(fegen::TypePtr content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::RightValue::Expression::getList( + std::vector> &content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::RightValue::Expression::getLeftValue(fegen::Value *content) { + return std::make_shared(content); +} + +std::shared_ptr +fegen::RightValue::Expression::binaryOperation( + std::shared_ptr lhs, + std::shared_ptr rhs, FegenOperator op) { + TypePtr resTy = fegen::inferenceType({lhs, rhs}, op); + return std::make_shared( + op, + std::vector>{lhs, rhs}); +} + +std::shared_ptr +fegen::RightValue::Expression::unaryOperation( + std::shared_ptr v, FegenOperator op) { + TypePtr resTy = fegen::inferenceType({v}, op); + return std::make_shared( + op, std::vector>{v}); +} + +// class ExpressionNode + +fegen::RightValue::ExpressionNode::ExpressionNode(LiteralKind kind, + bool ifConstexpr) + : Expression(false, kind, ifConstexpr) {} + +std::string fegen::RightValue::ExpressionNode::toString() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::RightValue::ExpressionNode::toStringForTypedef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::RightValue::ExpressionNode::toStringForOpdef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::RightValue::ExpressionNode::toStringForCppKind() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +fegen::TypePtr fegen::RightValue::ExpressionNode::getType() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +// class FunctionCall +inline bool isFuncParamsAllConstant( + std::vector> ¶ms) { + for (auto param : params) { + if (!param->isConstexpr()) { + return false; + } + } + return true; +} + +// TODO: invoke methods of FegenFunction +fegen::RightValue::FunctionCall::FunctionCall( + fegen::Function *func, + std::vector> params) + : ExpressionNode(fegen::RightValue::LiteralKind::FUNC_CALL, + isFuncParamsAllConstant(params)), + func(func), params(std::move(params)) {} + +std::string fegen::RightValue::FunctionCall::toString() { + return "FunctionCall::toString"; +} + +std::string fegen::RightValue::FunctionCall::toStringForTypedef() { + return "FunctionCall::toStringForTypedef"; +} + +std::string fegen::RightValue::FunctionCall::toStringForOpdef() { + return "FunctionCall::toStringForOpdef"; +} + +std::string fegen::RightValue::FunctionCall::toStringForCppKind() { + return "FunctionCall::toStringForCppKind"; +} + +std::any fegen::RightValue::FunctionCall::getContent() { return this; } + +fegen::TypePtr fegen::RightValue::FunctionCall::getType() { + return this->func->getReturnType(); +} + +// class OperationCall +fegen::RightValue::OperationCall::OperationCall( + fegen::Operation *op, + std::vector> params) + : ExpressionNode(fegen::RightValue::LiteralKind::OPERATION_CALL, + isFuncParamsAllConstant(params)), + op(op), params(std::move(params)) {} + +std::string fegen::RightValue::OperationCall::toString() { + return "OperationCall::toString"; +} + +std::string fegen::RightValue::OperationCall::toStringForTypedef() { + return "OperationCall::toStringForTypedef"; +} + +std::string fegen::RightValue::OperationCall::toStringForOpdef() { + return "OperationCall::toStringForOpdef"; +} + +std::string fegen::RightValue::OperationCall::toStringForCppKind() { + return "OperationCall::toStringForCppKind"; +} + +std::any fegen::RightValue::OperationCall::getContent() { return this; } + +fegen::TypePtr fegen::RightValue::OperationCall::getType() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +// class OperatorCall +fegen::RightValue::OperatorCall::OperatorCall( + fegen::FegenOperator op, + std::vector> params) + : ExpressionNode(fegen::RightValue::LiteralKind::OPERATION_CALL, + isFuncParamsAllConstant(params)), + op(op), params(std::move(params)) {} + +std::string fegen::RightValue::OperatorCall::toString() { + return "OperatorCall::toString"; +} + +std::string fegen::RightValue::OperatorCall::toStringForTypedef() { + return "OperatorCall::toStringForTypedef"; +} + +std::string fegen::RightValue::OperatorCall::toStringForOpdef() { + return "OperatorCall::toStringForOpdef"; +} + +inline bool isBinaryOperator(fegen::FegenOperator &op) { + switch (op) { + case fegen::FegenOperator::NEG: + case fegen::FegenOperator::NOT: + return false; + default: + return true; + } +} + +std::unordered_map + fegen::RightValue::OperatorCall::cppOperatorMap = { + {fegen::FegenOperator::OR, "||"}, + {fegen::FegenOperator::AND, "&&"}, + {fegen::FegenOperator::EQUAL, "=="}, + {fegen::FegenOperator::NOT_EQUAL, "!="}, + {fegen::FegenOperator::LESS, "<"}, + {fegen::FegenOperator::LESS_EQUAL, "<="}, + {fegen::FegenOperator::GREATER, ">"}, + {fegen::FegenOperator::GREATER_EQUAL, ">="}, + {fegen::FegenOperator::ADD, "+"}, + {fegen::FegenOperator::SUB, "-"}, + {fegen::FegenOperator::MUL, "*"}, + {fegen::FegenOperator::DIV, "/"}, + {fegen::FegenOperator::MOD, "%"}, + {fegen::FegenOperator::POWER, "pow"}, + {fegen::FegenOperator::NEG, "-"}, + {fegen::FegenOperator::NOT, "!"}}; + +std::string fegen::RightValue::OperatorCall::toStringForCppKind() { + std::string res; + if (isBinaryOperator(this->op)) { + res.append(this->params[0]->toStringForCppKind()); + res.append(" "); + res.append(OperatorCall::cppOperatorMap[this->op]); + res.append(" "); + res.append(this->params[1]->toStringForCppKind()); + } else { + res.append(OperatorCall::cppOperatorMap[this->op]); + res.append(this->params[0]->toStringForCppKind()); + } + return res; +} + +std::any fegen::RightValue::OperatorCall::getContent() { return this; } + +fegen::TypePtr fegen::RightValue::OperatorCall::getType() { + return inferenceType(this->params, this->op); +} + +// class ExpressionTerminal +fegen::RightValue::ExpressionTerminal::ExpressionTerminal( + fegen::RightValue::LiteralKind kind, bool ifConstexpr) + : Expression(true, kind, ifConstexpr) {} + +std::string fegen::RightValue::ExpressionTerminal::toString() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::RightValue::ExpressionTerminal::toStringForTypedef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::RightValue::ExpressionTerminal::toStringForOpdef() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +std::string fegen::RightValue::ExpressionTerminal::toStringForCppKind() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +fegen::TypePtr fegen::RightValue::ExpressionTerminal::getType() { + assert(FEGEN_NOT_IMPLEMENTED_ERROR); +} + +// class PlaceHolder +fegen::RightValue::PlaceHolder::PlaceHolder() + : ExpressionTerminal(fegen::RightValue::LiteralKind::MONOSTATE, true) {} + +std::any fegen::RightValue::PlaceHolder::getContent() { + return std::monostate(); +} + +std::string fegen::RightValue::PlaceHolder::toString() { return ""; } + +// class IntegerLiteral +fegen::RightValue::IntegerLiteral::IntegerLiteral(largestInt content, + size_t size) + : ExpressionTerminal(fegen::RightValue::LiteralKind::INT, true), + content(content) {} + +std::any fegen::RightValue::IntegerLiteral::getContent() { + return this->content; +} + +std::string fegen::RightValue::IntegerLiteral::toString() { + return std::to_string(this->content); +} + +std::string fegen::RightValue::IntegerLiteral::toStringForCppKind() { + return std::to_string(this->content); +} + +fegen::TypePtr fegen::RightValue::IntegerLiteral::getType() { + return fegen::Type::getIntegerType(fegen::RightValue::getInteger(this->size)); +} + +// class FloatPointLiteral +fegen::RightValue::FloatPointLiteral::FloatPointLiteral(long double content, + size_t size) + : ExpressionTerminal(fegen::RightValue::LiteralKind::FLOAT, true), + content(content) {} + +std::any fegen::RightValue::FloatPointLiteral::getContent() { + return this->content; +} + +std::string fegen::RightValue::FloatPointLiteral::toString() { + return std::to_string(this->content); +} + +std::string fegen::RightValue::FloatPointLiteral::toStringForCppKind() { + return std::to_string(this->content); +} + +fegen::TypePtr fegen::RightValue::FloatPointLiteral::getType() { + return fegen::Type::getFloatPointType( + fegen::RightValue::getInteger(this->size)); +} + +// class StringLiteral +fegen::RightValue::StringLiteral::StringLiteral(std::string content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::STRING, true), + content(content) {} + +std::any fegen::RightValue::StringLiteral::getContent() { + return this->content; +} + +std::string fegen::RightValue::StringLiteral::toString() { + std::string res; + res.append("\""); + res.append(this->content); + res.append("\""); + return res; +} + +std::string fegen::RightValue::StringLiteral::toStringForCppKind() { + return "\"" + this->content + "\""; +} + +fegen::TypePtr fegen::RightValue::StringLiteral::getType() { + return fegen::Type::getStringType(); +} + +// class TypeLiteral + +// Check params of content and return ture if params are all const expr. +inline bool isParamsConstant(fegen::TypePtr content) { + // for (auto param : content.getParameters()) { + // if (!param->getExpr()->isConstexpr()) { + // return false; + // } + // } + return true; +} + +// Get type of type literal. +fegen::TypePtr getTypeLiteralType(fegen::TypePtr content) { + if (content->getTypeLevel() == 2) { + return fegen::Type::getMetaTemplateType(); + } else if (content->getTypeLevel() == 3) { + return fegen::Type::getMetaType(); + } else { + return fegen::Type::getPlaceHolder(); + } +} + +fegen::RightValue::TypeLiteral::TypeLiteral(fegen::TypePtr content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::TYPE, + content->isConstant()), + content(content) {} + +std::any fegen::RightValue::TypeLiteral::getContent() { return this->content; } + +std::string fegen::RightValue::TypeLiteral::toString() { + return this->content->getTypeName(); +} + +std::string fegen::RightValue::TypeLiteral::toStringForTypedef() { + return this->content->toStringForTypedef(); +} + +std::string fegen::RightValue::TypeLiteral::toStringForOpdef() { + return this->content->toStringForOpdef(); +} + +std::string fegen::RightValue::TypeLiteral::toStringForCppKind() { + return this->content->toStringForCppKind(); +} + +fegen::TypePtr fegen::RightValue::TypeLiteral::getType() { + if (this->content->getTypeLevel() == 2) { + return fegen::Type::getMetaTemplateType(); + } else if (this->content->getTypeLevel() == 3) { + return fegen::Type::getMetaType(); + } else { + assert(false); + } +} + +// Return ture if all Expressions in content are all true. +bool isExpressionListConst( + std::vector> &content) { + for (auto p : content) { + if (!p->isConstexpr()) { + return false; + break; + } + } + return true; +} + +fegen::RightValue::ListLiteral::ListLiteral( + std::vector> &content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::VECTOR, + isExpressionListConst(content)), + content(content) {} + +std::any fegen::RightValue::ListLiteral::getContent() { return this->content; } + +std::string fegen::RightValue::ListLiteral::toString() { + std::string res; + res.append("["); + for (size_t i = 0; i <= this->content.size() - 1; i++) { + res.append(this->content[i]->toString()); + if (i != this->content.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; +} + +std::string fegen::RightValue::ListLiteral::toStringForTypedef() { + std::string res; + res.append("["); + for (size_t i = 0; i <= this->content.size() - 1; i++) { + res.append(this->content[i]->toStringForTypedef()); + if (i != this->content.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; +} + +std::string fegen::RightValue::ListLiteral::toStringForOpdef() { + std::string res; + res.append("["); + for (size_t i = 0; i <= this->content.size() - 1; i++) { + res.append(this->content[i]->toStringForOpdef()); + if (i != this->content.size() - 1) { + res.append(", "); + } + } + res.append("]"); + return res; +} + +fegen::TypePtr fegen::RightValue::ListLiteral::getType() { + return fegen::Type::getListType(this->content[0]->getType()); +} + +// class LeftValue +fegen::RightValue::LeftValue::LeftValue(fegen::Value *content) + : ExpressionTerminal(fegen::RightValue::LiteralKind::LEFT_VAR, + content->getExpr()->isConstexpr()), + content(content) {} + +std::any fegen::RightValue::LeftValue::getContent() { return this->content; } + +std::string fegen::RightValue::LeftValue::toString() { + return this->content->getName(); +} + +std::string fegen::RightValue::LeftValue::toStringForCppKind() { + return this->content->getName(); +} + +fegen::TypePtr fegen::RightValue::LeftValue::getType() { + return this->content->getType(); +} + +// class FegenRightValue +fegen::RightValue::RightValue( + std::shared_ptr content) + : content(content) {} + +fegen::RightValue::LiteralKind fegen::RightValue::getLiteralKind() { + return this->content->getKind(); +} + +std::string fegen::RightValue::toString() { return this->content->toString(); } + +std::string fegen::RightValue::toStringForTypedef() { + return this->content->toStringForTypedef(); +} + +std::string fegen::RightValue::toStringForOpdef() { + return this->content->toStringForOpdef(); +} + +std::string fegen::RightValue::toStringForCppKind() { + return this->content->toStringForCppKind(); +} + +std::any fegen::RightValue::getContent() { return this->content->getContent(); } + +fegen::TypePtr fegen::RightValue::getType() { return this->content->getType(); } + +std::shared_ptr fegen::RightValue::getExpr() { + return this->content; +} + +bool fegen::RightValue::isConstant() { return this->content->isConstexpr(); } + +fegen::RightValue fegen::RightValue::getPlaceHolder() { + return fegen::RightValue(fegen::RightValue::Expression::getPlaceHolder()); +} + +fegen::RightValue fegen::RightValue::getInteger(largestInt content, + size_t size) { + return fegen::RightValue( + fegen::RightValue::Expression::getInteger(content, size)); +} + +fegen::RightValue fegen::RightValue::getFloatPoint(long double content, + size_t size) { + return fegen::RightValue( + fegen::RightValue::Expression::getFloatPoint(content, size)); +} +fegen::RightValue fegen::RightValue::getString(std::string content) { + return fegen::RightValue(fegen::RightValue::Expression::getString(content)); +} +fegen::RightValue fegen::RightValue::getTypeRightValue(fegen::TypePtr content) { + return fegen::RightValue( + fegen::RightValue::Expression::getTypeRightValue(content)); +} + +fegen::RightValue fegen::RightValue::getList( + std::vector> &content) { + return fegen::RightValue(fegen::RightValue::Expression::getList(content)); +} +fegen::RightValue fegen::RightValue::getLeftValue(fegen::Value *content) { + return fegen::RightValue( + fegen::RightValue::Expression::getLeftValue(content)); +} + +fegen::RightValue fegen::RightValue::getByExpr( + std::shared_ptr expr) { + assert(expr != nullptr); + return fegen::RightValue(expr); +} + +// class FegenValue +fegen::Value::Value(fegen::TypePtr type, std::string name, + fegen::RightValue content) + : type(type), name(std::move(name)), content(std::move(content)) {} + +fegen::Value::Value(const fegen::Value &rhs) + : type(rhs.type), name(rhs.name), content(rhs.content) {} +fegen::Value::Value(fegen::Value &&rhs) + : type(std::move(rhs.type)), name(std::move(rhs.name)), + content(std::move(rhs.content)) {} + +fegen::Value *fegen::Value::get(fegen::TypePtr type, std::string name, + RightValue content) { + return new fegen::Value(type, std::move(name), std::move(content)); +} + +fegen::TypePtr fegen::Value::getType() { return this->type; } + +std::string fegen::Value::getName() { return this->name; } + +void fegen::Value::setContent(fegen::RightValue content) { + this->content = content; +} + +fegen::RightValue::LiteralKind fegen::Value::getContentKind() { + return this->content.getLiteralKind(); +} + +std::string fegen::Value::getContentString() { + return this->content.toString(); +} + +std::string fegen::Value::getContentStringForTypedef() { + return this->content.toStringForTypedef(); +} + +std::string fegen::Value::getContentStringForOpdef() { + return this->content.toStringForOpdef(); +} + +std::string fegen::Value::getContentStringForCppKind() { + return this->content.toStringForCppKind(); +} + +std::shared_ptr fegen::Value::getExpr() { + return this->content.getExpr(); +} + +fegen::ParserRule::ParserRule(std::string content, fegen::ParserNode *src, + antlr4::ParserRuleContext *ctx) + : content(content), src(src), ctx(ctx) {} + +fegen::ParserRule *fegen::ParserRule::get(std::string content, + fegen::ParserNode *src, + antlr4::ParserRuleContext *ctx) { + return new fegen::ParserRule(content, src, ctx); +} + +llvm::StringRef fegen::ParserRule::getContent() { return this->content; } + +bool fegen::ParserRule::addInput(fegen::Value input) { + auto name = input.getName(); + if (this->inputs.count(name) == 0) { + return false; + } + this->inputs.insert({name, new fegen::Value(input)}); + return true; +} + +bool fegen::ParserRule::addReturn(fegen::Value output) { + auto name = output.getName(); + if (this->returns.count(name) == 0) { + return false; + } + this->returns.insert({name, new fegen::Value(output)}); + return true; +} + +void fegen::ParserRule::setSrc(ParserNode *src) { this->src = src; } + +fegen::ParserNode::ParserNode(std::vector &&rules, + antlr4::ParserRuleContext *ctx, + fegen::ParserNode::NodeType ntype) + : rules(rules), ctx(ctx), ntype(ntype) {} + +fegen::ParserNode * +fegen::ParserNode::get(std::vector rules, + antlr4::ParserRuleContext *ctx, + fegen::ParserNode::NodeType ntype) { + return new fegen::ParserNode(std::move(rules), ctx, ntype); +} +fegen::ParserNode *fegen::ParserNode::get(antlr4::ParserRuleContext *ctx, + fegen::ParserNode::NodeType ntype) { + std::vector rules; + return new fegen::ParserNode(std::move(rules), ctx, ntype); +} + +void fegen::ParserNode::addFegenRule(fegen::ParserRule *rule) { + this->rules.push_back(rule); +} + +fegen::ParserNode::~ParserNode() { + for (auto rule : this->rules) { + delete rule; + } +} + +void fegen::Manager::setModuleName(std::string name) { + this->moduleName = name; +} + +std::string getChildrenText(antlr4::tree::ParseTree *ctx) { + std::string ruleText; + for (auto child : ctx->children) { + if (antlr4::tree::TerminalNode::is(child)) { + ruleText.append(child->getText()).append(" "); + } else { + ruleText.append(getChildrenText(child)).append(" "); + } + } + return ruleText; +} + +fegen::Manager::OverloadedType::OverloadedType(TypeDefination *ty) + : tys({ty}) {} +fegen::Manager::OverloadedType::OverloadedType( + std::initializer_list &&tys) + : tys(tys) {} + +fegen::TypeDefination * +fegen::Manager::OverloadedType::OverloadedType::get(unsigned i) { + return this->tys[i]; +} + +fegen::Manager::Manager() {} + +namespace fegen { + +class Emitter { +private: + std::ostream &stream; + int tabCount; + bool isNewLine; + +public: + Emitter() = delete; + Emitter(Emitter &) = delete; + Emitter(Emitter &&) = delete; + Emitter(std::ostream &stream) + : stream(stream), tabCount(0), isNewLine(true) {} + void tab() { tabCount++; } + + void shiftTab() { + tabCount--; + if (tabCount < 0) { + tabCount = 0; + } + } + + void newLine() { + this->stream << std::endl; + isNewLine = true; + } + + std::ostream &operator<<(std::string s) { + if (this->isNewLine) { + for (int i = 0; i <= (this->tabCount - 1); i++) { + this->stream << '\t'; + } + this->isNewLine = false; + } + this->stream << s; + return this->stream; + } +}; +} // namespace fegen + +void fegen::Manager::emitG4() { + std::ofstream fileStream; + fileStream.open(this->moduleName + ".g4"); + fegen::Emitter emitter(fileStream); + emitter << "grammar " << this->moduleName << ";"; + emitter.newLine(); + for (auto node_pair : this->nodeMap) { + auto nodeName = node_pair.first; + auto node = node_pair.second; + emitter << nodeName; + emitter.newLine(); + emitter.tab(); + auto ruleCount = node->rules.size(); + if (ruleCount > 0) { + emitter << ": " << getChildrenText(node->rules[0]->ctx); + emitter.newLine(); + for (size_t i = 1; i <= ruleCount - 1; i++) { + emitter << "| " << getChildrenText(node->rules[i]->ctx); + emitter.newLine(); + } + emitter << ";" << std::endl; + } + emitter.shiftTab(); + emitter.newLine(); + } + fileStream.close(); +} + +// TODO: emit to file +void fegen::Manager::emitTypeDefination() { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Types.td"); + fegen::Emitter emitter(fileStream); + // file head + std::string mn(this->moduleName); + std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); + emitter << "#ifndef " << mn << "_TYPE_TD"; + emitter.newLine(); + emitter << "#define " << mn << "_TYPE_TD"; + emitter << "\n"; + emitter.newLine(); + + // include files + emitter << "include \"mlir/IR/AttrTypeBase.td\""; + emitter.newLine(); + emitter << "include \"" << this->moduleName << "Dialect.td\""; + emitter << "\n"; + emitter.newLine(); + // Type class defination + std::string typeClassName = this->moduleName + "Type"; + emitter << "class " << typeClassName + << " traits = []>"; + emitter.tab(); + emitter << ": TypeDef {"; + emitter.newLine(); + emitter << "let mnemonic = typeMnemonic;"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}" << std::endl; + emitter.newLine(); + + for (auto pair : this->typeDefMap) { + auto tyDef = pair.second.get(0); + if (!tyDef->isCustome()) { + continue; + } + auto typeName = pair.first; + // head of typedef + emitter << "def " << typeName << " : " << typeClassName << "<\"" << typeName + << "\", \"" << tyDef->getMnemonic() << "\"> {"; + emitter.newLine(); + emitter.tab(); + // summary + emitter << "let summary = \"This is generated by buddy fegen.\";"; + emitter.newLine(); + // description + emitter << "let description = [{ This is generated by buddy fegen. }];"; + emitter.newLine(); + // parameters + emitter << "let parameters = ( ins"; + emitter.newLine(); + emitter.tab(); + for (size_t i = 0; i <= tyDef->getParameters().size() - 1; i++) { + auto param = tyDef->getParameters()[i]; + auto paramTy = param->getType(); + auto paramName = param->getName(); + auto paramTyStr = paramTy->toStringForTypedef(); + emitter << paramTyStr << ":" << "$" << paramName; + if (i != tyDef->getParameters().size() - 1) { + emitter << ", "; + } + emitter.newLine(); + } + emitter.shiftTab(); + emitter << ");"; + emitter.newLine(); + // assemblyFormat + // TODO: handle list, Type ... + emitter << "let assemblyFormat = [{"; + emitter.newLine(); + emitter.tab(); + emitter << "`<` "; + for (size_t i = 0; i <= tyDef->getParameters().size() - 1; i++) { + auto param = tyDef->getParameters()[i]; + auto paramName = param->getName(); + emitter << "$" << paramName << " "; + if (i != tyDef->getParameters().size() - 1) { + emitter << "`x` "; + } + } + emitter << "`>`"; + emitter.shiftTab(); + emitter.newLine(); + emitter << "}];"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + } + emitter.shiftTab(); + emitter << "\n"; + emitter << "#endif // " << mn << "_TYPE_TD"; + fileStream.close(); +} + +void fegen::Manager::emitOpDefination() { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Ops.td"); + fegen::Emitter emitter(fileStream); + + // file head + std::string mn(this->moduleName); + std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); + emitter << "#ifndef " << mn << "_OPS_TD"; + emitter.newLine(); + emitter << "#define " << mn << "_OPS_TD"; + emitter << "\n"; + emitter.newLine(); + + // TODO: custome include files + // include + emitter << "include \"mlir/IR/BuiltinAttributes.td\""; + emitter.newLine(); + emitter << "include \"mlir/IR/BuiltinTypes.td\""; + emitter.newLine(); + emitter << "include \"mlir/IR/CommonAttrConstraints.td\""; + emitter.newLine(); + emitter << "include \"" << this->moduleName << "Dialect.td\""; + emitter.newLine(); + emitter << "include \"" << this->moduleName << "Types.td\""; + emitter.newLine(); + emitter << "\n"; + + // op class defination + std::string classname = this->moduleName + "Op"; + emitter << "class " << classname + << " traits = []>:"; + emitter.newLine(); + emitter.tab(); + emitter << "Op;"; + emitter << "\n"; + emitter.shiftTab(); + emitter.newLine(); + + // op definations + for (auto pair : this->operationMap) { + auto opName = pair.first; + auto opDef = pair.second; + // head of def + emitter << "def " << opName << " : " << classname << "<\"" << opName + << "\", [Pure]> {"; + emitter.newLine(); + { + emitter.tab(); + // summary and description + emitter << "let summary = \"This is generated by buddy fegen.\";"; + emitter.newLine(); + emitter << "let description = [{This is generated by buddy fegen.}];"; + emitter.newLine(); + // arguments + emitter << "let arguments = ( ins "; + emitter.newLine(); + { + emitter.tab(); + for (auto param : opDef->getArguments()) { + auto paramTyStr = param->getType()->toStringForOpdef(); + auto paramName = param->getName(); + emitter << paramTyStr << " : $" << paramName; + emitter.newLine(); + } + emitter.shiftTab(); + } + emitter << ");"; + emitter.newLine(); + // results + emitter << "let results = (outs "; + emitter.newLine(); + { + emitter.tab(); + for (auto param : opDef->getResults()) { + auto paramTyStr = param->getType()->toStringForOpdef(); + auto paramName = param->getName(); + emitter << paramTyStr << " : $" << paramName; + emitter.newLine(); + } + emitter.shiftTab(); + } + emitter << ");"; + emitter.newLine(); + // end of def + emitter.shiftTab(); + } + emitter << "}"; + emitter.newLine(); + } + + // end of file + emitter << "\n"; + emitter << "#endif // " << mn << "_DIALECT_TD"; + fileStream.close(); +} + +void fegen::Manager::emitDialectDefination() { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Dialect.td"); + fegen::Emitter emitter(fileStream); + + // file head + std::string mn(this->moduleName); + std::transform(mn.begin(), mn.end(), mn.begin(), ::toupper); + emitter << "#ifndef " << mn << "_DIALECT_TD"; + emitter.newLine(); + emitter << "#define " << mn << "_DIALECT_TD"; + emitter << "\n"; + emitter.newLine(); + + // include + emitter << "include \"mlir/IR/OpBase.td\""; + emitter << "\n"; + emitter.newLine(); + + // dialect defination + emitter << "def " << this->moduleName << "_Dialect : Dialect {"; + emitter.newLine(); + emitter.tab(); + emitter << "let name = \"" << this->moduleName << "\";"; + emitter.newLine(); + emitter << "let summary = \"This is generated by buddy fegen.\";"; + emitter.newLine(); + emitter << "let description = [{This is generated by buddy fegen.}];"; + emitter.newLine(); + emitter << "let cppNamespace = \"::mlir::" << this->moduleName << "\";"; + emitter.newLine(); + emitter << "let extraClassDeclaration = [{"; + emitter.newLine(); + emitter.tab(); + emitter << "/// Register all types."; + emitter.newLine(); + emitter << "void registerTypes();"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}];"; + emitter.newLine(); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + + // end of file + emitter << "#endif // " << mn << "_DIALECT_TD"; + fileStream.close(); +} + +void fegen::Manager::emitTdFiles() { + this->emitDialectDefination(); + this->emitTypeDefination(); + this->emitOpDefination(); +} + +void fegen::Manager::initbuiltinTypes() { + // placeholder type + auto placeholderTypeDefination = fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_PLACEHOLDER, {}, nullptr, false); + this->typeDefMap.insert({FEGEN_PLACEHOLDER, placeholderTypeDefination}); + + // Type + this->typeDefMap.insert( + {FEGEN_TYPE, fegen::TypeDefination::get(FEGEN_DIALECT_NAME, FEGEN_TYPE, + {}, nullptr, false)}); + + // TypeTemplate + this->typeDefMap.insert( + {FEGEN_TYPETEMPLATE, + fegen::TypeDefination::get(FEGEN_DIALECT_NAME, FEGEN_TYPETEMPLATE, {}, + nullptr, false)}); + + // Integer + auto intTydef = fegen::TypeDefination::get(FEGEN_DIALECT_NAME, FEGEN_INTEGER, + {}, nullptr, false); + auto paramOfIntTydef = Value::get( + std::make_shared(RightValue::getInteger(32), intTydef), + "size", fegen::RightValue::getPlaceHolder()); + intTydef->parameters.push_back(paramOfIntTydef); + this->typeDefMap.insert({FEGEN_INTEGER, intTydef}); + + // FloatPoint + this->typeDefMap.insert( + {FEGEN_FLOATPOINT, + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_FLOATPOINT, + {fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getPlaceHolder())}, + nullptr, false)}); + + // String + this->typeDefMap.insert({FEGEN_STRING, fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_STRING, + {}, nullptr, false)}); + + // Vector + this->typeDefMap.insert( + {FEGEN_VECTOR, + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_VECTOR, + { + fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder()), + fegen::Value::get(fegen::Type::getInt32Type(), "size", + fegen::RightValue::getPlaceHolder()), + }, + nullptr, false)}); + + // List (this should be ahead of Tensor and Any Type defination) + this->typeDefMap.insert({ + FEGEN_LIST, + {fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_LIST, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false), // element type is type instance + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_LIST, + {fegen::Value::get(fegen::Type::getMetaTemplateType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false)} // element type is type template + }); + + // Tensor + this->typeDefMap.insert( + {FEGEN_TENSOR, + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_TENSOR, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false)}); + + // Optional + this->typeDefMap.insert( + {FEGEN_OPTINAL, + { + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_OPTINAL, + {fegen::Value::get(fegen::Type::getMetaType(), "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false), // element type is type instance + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_OPTINAL, + {fegen::Value::get(fegen::Type::getMetaTemplateType(), + "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false) // element type is type template + }}); + + // Any + this->typeDefMap.insert( + {FEGEN_ANY, + { + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_ANY, + {fegen::Value::get( + fegen::Type::getListTemplate(fegen::Type::getMetaType()), + "elementType", fegen::RightValue::getPlaceHolder())}, + nullptr, false), // elements are Type, ex: Any<[Integer<32>, + // FloatPoint<32>]> + fegen::TypeDefination::get( + FEGEN_DIALECT_NAME, FEGEN_ANY, + {fegen::Value::get(fegen::Type::getListTemplate( + fegen::Type::getMetaTemplateType()), + "elementType", + fegen::RightValue::getPlaceHolder())}, + nullptr, false) // elements are TypeTemplate, ex: Any<[Integer, + // FloatPoint]> + }}); +} + +fegen::TypeDefination *fegen::Manager::getTypeDefination(std::string name) { + auto it = this->typeDefMap.find(name); + if (it != this->typeDefMap.end()) { + return it->second.get(0); + } + assert(false); +} + +fegen::TypeDefination * +fegen::Manager::getOverloadedTypeDefination(std::string name) { + auto it = this->typeDefMap.find(name); + if (it != this->typeDefMap.end()) { + return it->second.get(1); + } + assert(false); +} + +bool fegen::Manager::addTypeDefination(fegen::TypeDefination *tyDef) { + if (this->typeDefMap.count(tyDef->name) != 0) { + return false; + } + this->typeDefMap.insert({tyDef->name, {tyDef}}); + return true; +} + +bool fegen::Manager::addOverloadedTypeDefination(TypeDefination *tyDef) { + auto it = this->typeDefMap.find(tyDef->name); + if (it != this->typeDefMap.end()) { + it->second.tys[1] = tyDef; + } + assert(false); +} + +fegen::Operation *fegen::Manager::getOperationDefination(std::string name) { + return this->operationMap[name]; +} + +bool fegen::Manager::addOperationDefination(fegen::Operation *opDef) { + if (this->operationMap.count(opDef->getOpName()) != 0) { + return false; + } + this->operationMap[opDef->getOpName()] = opDef; + return true; +} + +void fegen::Manager::addStmtContent(antlr4::ParserRuleContext *ctx, + std::any content) { + this->stmtContentMap.insert({ctx, content}); +} + +fegen::Manager &fegen::Manager::getManager() { + static fegen::Manager fmg; + return fmg; +} + +fegen::Manager::~Manager() { + // release nodes + for (auto node_pair : this->nodeMap) { + delete node_pair.second; + } +} + +fegen::TypePtr fegen::inferenceType( + std::vector> operands, + fegen::FegenOperator op) { + // TODO: infer type + return fegen::Type::getInt32Type(); +} + +namespace fegen { + +class StmtVisitor : public FegenParserBaseVisitor { +private: + Manager &manager; + Emitter &emitter; + +public: + StmtVisitor(Emitter &emitter) + : manager(Manager::getManager()), emitter(emitter) {} + std::any visitFunctionDecl(FegenParser::FunctionDeclContext *ctx) override { + auto returnType = + std::any_cast(manager.stmtContentMap[ctx]); + auto functionName = + std::any_cast(manager.stmtContentMap[ctx->funcName()]); + emitter << returnType->toStringForCppKind() << " " << functionName << "("; + auto paraList = std::any_cast>( + manager.stmtContentMap[ctx->funcParams()]); + for (auto para : paraList) { + emitter << para->getType()->toStringForCppKind() << " " + << para->getName(); + if (para != paraList.back()) + emitter << ", "; + } + emitter << "){"; + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + emitter.newLine(); + return nullptr; + } + std::any + visitStatementBlock(FegenParser::StatementBlockContext *ctx) override { + for (size_t i = 0; i < ctx->statement().size(); i++) { + this->visit(ctx->statement(i)); + if (!(ctx->statement(i)->ifStmt() || ctx->statement(i)->forStmt())) + emitter << ";"; + emitter.newLine(); + } + return nullptr; + } + std::any visitVarDeclStmt(FegenParser::VarDeclStmtContext *ctx) override { + auto varType = std::any_cast(manager.stmtContentMap[ctx]); + auto varName = ctx->identifier()->getText(); + emitter << varType->toStringForCppKind() << " " << varName; + if (ctx->expression()) { + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + emitter << " = " << expr->toStringForCppKind(); + } + return nullptr; + } + std::any visitAssignStmt(FegenParser::AssignStmtContext *ctx) override { + auto varName = ctx->identifier()->getText(); + auto expr = this->manager.getStmtContent( + ctx->expression()); + emitter << varName << " = " << expr->toStringForCppKind(); + return nullptr; + } + // TODO:测试并补足函数调用 + std::any visitFunctionCall(FegenParser::FunctionCallContext *ctx) override { + auto function = + std::any_cast(manager.stmtContentMap[ctx]); + emitter << function->getName() << " ("; + for (auto para : function->getInputTypeList()) { + emitter << para->getName(); + if (para != function->getInputTypeList().back()) + emitter << ", "; + } + // TODO:补充functioncall作为操作数的情况 + emitter << ");"; + emitter.newLine(); + return nullptr; + } + std::any visitIfStmt(FegenParser::IfStmtContext *ctx) override { + this->visit(ctx->ifBlock(0)); + for (size_t i = 1; i < ctx->ifBlock().size(); i++) { + emitter << " else "; + this->visit(ctx->ifBlock(i)); + } + if (ctx->elseBlock()) + this->visit(ctx->elseBlock()); + return nullptr; + } + std::any visitIfBlock(FegenParser::IfBlockContext *ctx) override { + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + + emitter << "if (" << expr->toStringForCppKind() << "){"; + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + return nullptr; + } + std::any visitElseBlock(FegenParser::ElseBlockContext *ctx) override { + emitter << "else {"; + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + return nullptr; + } + std::any visitForStmt(FegenParser::ForStmtContext *ctx) override { + if (ctx->varDeclStmt()) { + emitter << "for ("; + this->visit(ctx->varDeclStmt()); + emitter << "; "; + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + emitter << expr->toStringForCppKind() << "; "; + this->visit(ctx->assignStmt(0)); + emitter << ") {"; + } else { + this->visit(ctx->assignStmt(0)); + emitter << " "; + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + emitter << expr->toStringForCppKind() << "; "; + this->visit(ctx->assignStmt(1)); + emitter << ") {"; + } + emitter.tab(); + emitter.newLine(); + this->visit(ctx->statementBlock()); + emitter.shiftTab(); + emitter << "}"; + return nullptr; + } + std::any visitReturnBlock(FegenParser::ReturnBlockContext *ctx) override { + auto expr = std::any_cast>( + manager.stmtContentMap[ctx->expression()]); + emitter << "return " << expr->toStringForCppKind(); + return nullptr; + } + std::any visitOpDecl(FegenParser::OpDeclContext *ctx) override { + return nullptr; + } + + // TODO: add op declaration/invoke +}; + +} // namespace fegen +void fegen::Manager::emitBuiltinFunction( + fegen::FegenParser::FegenSpecContext *moduleAST) { + std::ofstream fileStream; + fileStream.open(this->moduleName + "Function.cpp"); + fegen::Emitter emitter(fileStream); + // Emitter emitter(std::cout); + StmtVisitor visitor(emitter); + visitor.visit(moduleAST); + fileStream.close(); +} diff --git a/frontend/FrontendGen/lib/FegenParser.g4 b/frontend/FrontendGen/lib/FegenParser.g4 new file mode 100644 index 000000000..19d8ccde3 --- /dev/null +++ b/frontend/FrontendGen/lib/FegenParser.g4 @@ -0,0 +1,456 @@ +parser grammar FegenParser; + +options { + tokenVocab = FegenLexer; +} + +fegenSpec + : fegenDecl (prequelConstruct | functionDecl | typeDefinationDecl | statement | opDecl | rules)* EOF + ; + +fegenDecl + : FEGEN identifier + ; + +// preprocess declare +prequelConstruct + : BeginInclude INCLUDE_CONTENT* EndInclude + ; + +// function declare +functionDecl + : typeSpec funcName LeftParen funcParams? RightParen statementBlock + ; + +funcName + : identifier + ; + +funcParams + : typeSpec identifier (Comma typeSpec identifier)* + ; + +// typedef declare +typeDefinationDecl + : TYPEDEF typeDefinationName typeDefinationBlock + ; + +typeDefinationName + : identifier + ; + +typeDefinationBlock + : LeftBrace parametersSpec assemblyFormatSpec? RightBrace + ; + +parametersSpec + : PARAMETERS varDecls + ; + +assemblyFormatSpec + : ASSEMBLY_FORMAT LeftBracket StringLiteral RightBracket + ; + +// opdef declare +opDecl + : OPDEF opName opBlock + ; + +opName + : identifier + ; + +opBlock + : LeftBrace argumentSpec? resultSpec? bodySpec? RightBrace + ; + +argumentSpec + : ARGUMENTS varDecls + ; + +resultSpec + : RESULTS varDecls + ; + +bodySpec + : BODY statementBlock + ; + +// rule definations +rules + : ruleSpec+ + ; + +ruleSpec + : parserRuleSpec + | lexerRuleSpec + ; + +parserRuleSpec + : ParserRuleName Colon ruleBlock Semi + ; + +ruleBlock + : ruleAltList + ; + +ruleAltList + : actionAlt (OR actionAlt)* + ; + +actionAlt + : alternative actionBlock? + ; + +alternative + : element* + ; + +element + : atom (ebnfSuffix |) + | ebnf + ; + +atom + : terminalDef + | ruleref + | notSet + ; + +// terminal rule reference +terminalDef + : LexerRuleName + | StringLiteral + ; + +// parser rule reference +ruleref + : ParserRuleName + ; + +notSet + : Tilde setElement + | Tilde blockSet + ; + +setElement + : LexerRuleName + | StringLiteral + | characterRange + ; + +characterRange + : StringLiteral Range StringLiteral + ; + +blockSet + : LeftParen setElement (OR setElement)* RightParen + ; + +ebnfSuffix + : QuestionMark QuestionMark? + | Star QuestionMark? + | Plus QuestionMark? + ; + +ebnf + : block blockSuffix? + ; + +block + : LeftParen altList RightParen + ; + +blockSuffix + : ebnfSuffix + ; + +altList + : alternative (OR alternative)* + ; + +// lexer rule +lexerRuleSpec + : LexerRuleName Colon lexerRuleBlock Semi + ; + +lexerRuleBlock + : lexerAltList + ; + +lexerAltList + : lexerAlt (OR lexerAlt)* + ; + +lexerAlt + : lexerElements lexerCommands? + | + ; + +// E.g., channel(HIDDEN), skip, more, mode(INSIDE), push(INSIDE), pop +lexerCommands + : Arror lexerCommand (Comma lexerCommand)* + ; + +lexerCommand + : lexerCommandName + ; + +lexerCommandName + : identifier + ; + +lexerElements + : lexerElement+ + | + ; + +lexerElement + : lexerAtom ebnfSuffix? + | lexerBlock ebnfSuffix? + ; + +lexerAtom + : characterRange + | terminalDef + | notSet + | Dot + ; + +lexerBlock + : LeftParen lexerAltList RightParen + ; + +// action block declare +actionBlock + : LeftBrace inputsSpec? returnsSpec? actionSpec? RightBrace + ; + +inputsSpec + : INPUTS varDecls + ; + +varDecls + : LeftBracket typeSpec identifier (Comma typeSpec identifier)* RightBracket + ; + +prefixedName + : identifier (Dot identifier)? + ; + +identifier + : LexerRuleName + | ParserRuleName + ; + +returnsSpec + : RETURNS varDecls + ; + +actionSpec + : ACTIONS statementBlock + ; + +statementBlock + : LeftBrace statement* RightBrace + ; + +statement + : varDeclStmt Semi + | assignStmt Semi + | functionCall Semi + | opInvokeStmt Semi + | ifStmt + | forStmt + | returnBlock Semi + ; + +varDeclStmt + : typeSpec identifier (Assign expression)? + ; + +assignStmt + : identifier Assign expression + ; + +functionCall + : funcName LeftParen (expression (Comma expression)*)? RightParen + ; + +opInvokeStmt + : opName LeftParen opParams? (Comma opResTypeParams)? RightParen+ + ; + +opParams + : identifier (Comma identifier)* + ; + +opResTypeParams + : typeInstance (Comma typeInstance)* + ; + +ifStmt + : ifBlock (ELSE ifBlock)* (elseBlock)? + ; + +ifBlock: + IF LeftParen expression RightParen statementBlock + ; + +elseBlock + : ELSE statementBlock + ; + +forStmt + : FOR LeftParen (assignStmt | varDeclStmt) Semi expression Semi assignStmt RightParen statementBlock + ; + +returnBlock + : RETURN expression + ; + +// expression +expression + : andExpr (Logic_OR andExpr)* + ; + +andExpr + : equExpr (AND equExpr )* + ; + +equExpr + : compareExpr ((EQUAL | NOT_EQUAL) compareExpr)* + ; + +compareExpr + : addExpr ((Less | LessEqual | Greater | GreaterEqual) addExpr)* + ; + +addExpr + : term ((Plus | Minus) term)* + ; + +term + : powerExpr ((Star | Div | MOD) powerExpr)* + ; + +powerExpr + : unaryExpr (StarStar unaryExpr)* + ; + +unaryExpr + : (Minus | Plus | Exclamation)? primaryExpr + ; + +parenSurroundedExpr + : LeftParen expression RightParen + ; + +primaryExpr + : constant + | identifier + | functionCall + | parenSurroundedExpr + | contextMethodInvoke + | typeSpec + | variableAccess + ; + +constant + : numericLiteral + | charLiteral + | boolLiteral + | listLiteral + ; + +// ex: $ctx(0).getText() +contextMethodInvoke + : Dollar identifier LeftParen intLiteral? RightParen Dot functionCall + ; + +variableAccess + : identifier LeftBracket expression RightBracket + ; + +numericLiteral + : intLiteral + | realLiteral + ; + +intLiteral + : UnsignedInt + | (Plus | Minus) UnsignedInt + ; + +realLiteral + : ScienceReal + ; + +charLiteral + : StringLiteral + ; + +boolLiteral + : ConstBoolean + ; + +listLiteral + : LeftBracket (expression (Comma expression)*)? RightBracket + ; + +// type system +typeSpec + : valueKind? typeInstance # typeInstanceSpec + | valueKind? typeTemplate # typeTemplateSpce + | valueKind? collectType # collectTypeSpec + ; + +valueKind + : CPP + | OPERAND + | ATTRIBUTE + ; + + +typeInstance + : typeTemplate Less typeTemplateParam (Comma typeTemplateParam)* Greater + | builtinTypeInstances + | identifier + ; + +typeTemplate + : prefixedName + | builtinTypeTemplate + | TYPE + ; + +typeTemplateParam + : expression + | builtinTypeInstances + ; + +builtinTypeInstances + : BOOL + | INT + | FLOAT + | DOUBLE + | CHAR + | STRING + ; + +builtinTypeTemplate + : INTEGER + | FLOATPOINT + | TENSOR + | VECTOR + ; + +collectType + : collectProtoType Less expression Greater + ; + +collectProtoType + : ANY + | LIST + | OPTIONAL + ; \ No newline at end of file diff --git a/frontend/FrontendGen/lib/FegenVisitor.cpp b/frontend/FrontendGen/lib/FegenVisitor.cpp new file mode 100644 index 000000000..761fe0529 --- /dev/null +++ b/frontend/FrontendGen/lib/FegenVisitor.cpp @@ -0,0 +1,11 @@ +#include "FegenVisitor.h" + +bool fegen::checkParams(std::vector &expected, + std::vector &actual) { + return true; +} + +bool fegen::checkListLiteral( + std::vector> &listLiteral) { + return true; +} \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Lexer.cpp b/frontend/FrontendGen/lib/Lexer.cpp deleted file mode 100644 index 6cce20df8..000000000 --- a/frontend/FrontendGen/lib/Lexer.cpp +++ /dev/null @@ -1,198 +0,0 @@ -//====- Lexer.cpp --------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Lexer.h" -#include "llvm/Support/raw_ostream.h" -using namespace frontendgen; -/// some function about handing characters. -namespace charinfo { -inline bool isASCLL(char ch) { return static_cast(ch) <= 127; } - -inline bool isWhitespace(char ch) { - return isASCLL(ch) && (ch == ' ' || ch == '\t' || ch == '\f' || ch == '\v' || - ch == '\r' || ch == '\n'); -} - -inline bool isIdentifierHead(char ch) { - return isASCLL(ch) && - (ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')); -} - -inline bool isDigit(char ch) { return isASCLL(ch) && (ch >= '0' && ch <= '9'); } - -inline bool isIdentifierBody(char ch) { - return isIdentifierHead(ch) || isDigit(ch); -} -} // namespace charinfo - -/// Add keyword to keywordmap. -void KeyWordManager::addKeyWord(llvm::StringRef name, tokenKinds kind) { - keywordMap.insert(std::make_pair(name, kind)); -} - -/// A function add all keywords. -void KeyWordManager::addKeyWords() { -#define KEYWORD(NAME, FLAG) addKeyWord(#NAME, tokenKinds::kw_##NAME); -#include "Token.def" -} - -/// Determine if a string is a keyword. -tokenKinds KeyWordManager::getKeyWord(llvm::StringRef name, tokenKinds kind) { - auto result = keywordMap.find(name); - if (result != keywordMap.end()) - return result->second; - return kind; -} - -bool Token::is(tokenKinds kind) { return kind == tokenKind; } - -llvm::SMLoc Token::getLocation() { return llvm::SMLoc::getFromPointer(start); } -//// Get next token. -void Lexer::next(Token &token) { - // Skip whitespace. - while (*curPtr && charinfo::isWhitespace(*curPtr)) - curPtr++; - if (!*curPtr) { - token.setTokenKind(tokenKinds::eof); - return; - } - // Get identifier. - if (charinfo::isIdentifierHead(*curPtr)) { - identifier(token); - return; - } else if (charinfo::isDigit(*curPtr)) { - number(token); - return; - } else if (*curPtr == ';') { - formToken(token, curPtr + 1, tokenKinds::semi); - return; - } else if (*curPtr == ':') { - formToken(token, curPtr + 1, tokenKinds::colon); - return; - } else if (*curPtr == '\'') { - formToken(token, curPtr + 1, tokenKinds::apostrophe); - return; - } else if (*curPtr == '(') { - formToken(token, curPtr + 1, tokenKinds::parentheseOpen); - return; - } else if (*curPtr == ')') { - formToken(token, curPtr + 1, tokenKinds::parentheseClose); - return; - } else if (*curPtr == '*') { - formToken(token, curPtr + 1, tokenKinds::asterisk); - return; - } else if (*curPtr == '?') { - formToken(token, curPtr + 1, tokenKinds::questionMark); - return; - } else if (*curPtr == '+') { - formToken(token, curPtr + 1, tokenKinds::plus); - return; - } else if (*curPtr == '=') { - formToken(token, curPtr + 1, tokenKinds::equal); - return; - } else if (*curPtr == '{') { - formToken(token, curPtr + 1, tokenKinds::curlyBlacketOpen); - return; - } else if (*curPtr == '}') { - formToken(token, curPtr + 1, tokenKinds::curlyBlacketClose); - return; - } else if (*curPtr == '$') { - formToken(token, curPtr + 1, tokenKinds::dollar); - return; - } else if (*curPtr == ',') { - formToken(token, curPtr + 1, tokenKinds::comma); - return; - } else if (*curPtr == '<') { - formToken(token, curPtr + 1, tokenKinds::angleBracketOpen); - return; - } else if (*curPtr == '>') { - formToken(token, curPtr + 1, tokenKinds::angleBracketClose); - return; - } else if (*curPtr == '[') { - formToken(token, curPtr + 1, tokenKinds::squareBracketOpen); - return; - } else if (*curPtr == ']') { - formToken(token, curPtr + 1, tokenKinds::squareBracketClose); - return; - } else if (*curPtr == '"') { - formToken(token, curPtr + 1, tokenKinds::doubleQuotationMark); - return; - } - token.tokenKind = tokenKinds::unknown; -} - -void Lexer::identifier(Token &token) { - const char *start = curPtr; - const char *end = curPtr + 1; - while (charinfo::isIdentifierBody(*end)) - ++end; - llvm::StringRef name(start, end - start); - tokenKinds kind = keywordManager.getKeyWord(name, tokenKinds::identifier); - formToken(token, end, kind); -} - -void Lexer::formToken(Token &token, const char *tokenEnd, tokenKinds kind) { - int length = tokenEnd - curPtr; - token.start = curPtr; - token.length = length; - token.tokenKind = kind; - curPtr = tokenEnd; -} - -void Lexer::number(Token &token) { - const char *end = curPtr; - end++; - while (charinfo::isDigit(*end)) - end++; - formToken(token, end, tokenKinds::number); -} -/// Get the corresponding content according to start and end. -llvm::StringRef Lexer::getMarkContent(std::string start, std::string end) { - while (*curPtr && charinfo::isWhitespace(*curPtr)) - curPtr++; - int index = start.find(*curPtr); - if (index == -1) - return llvm::StringRef(); - char s = start[index]; - char e = end[index]; - const char *endPtr = curPtr + 1; - int number = 1; - if (s == e) - while (*endPtr != e) - endPtr++; - else - while (number) { - if (*endPtr == s) - number++; - if (*endPtr == e) - number--; - if (number) - endPtr++; - } - endPtr++; - llvm::StringRef content(curPtr, endPtr - curPtr); - curPtr = endPtr; - return content; -} -/// Get the corresponding content according to statr and ch. -llvm::StringRef Lexer::getEndChContent(const char *start, char ch) { - const char *endPtr = curPtr; - while (*endPtr != ch) - endPtr++; - endPtr++; - curPtr = endPtr; - return llvm::StringRef(start, endPtr - start); -} diff --git a/frontend/FrontendGen/lib/Parser.cpp b/frontend/FrontendGen/lib/Parser.cpp deleted file mode 100644 index 152462fce..000000000 --- a/frontend/FrontendGen/lib/Parser.cpp +++ /dev/null @@ -1,403 +0,0 @@ -//====- Parser.cpp -------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Parser.h" -#include "AST.h" -#include "Lexer.h" -#include "Sema.h" -#include "unistd.h" -#include "llvm/Support/raw_ostream.h" -using namespace frontendgen; - -void Parser::advance() { lexer.next(token); } - -void Parser::lookToken() { - while (token.getKind() != tokenKinds::eof) { - llvm::outs() << token.getContent() << '\n'; - llvm::outs() << "token type:" << token.getTokenName() << '\n'; - advance(); - } -} - -/// If current token's kind is expected kind, get next token. -/// If not, an error is reported. -bool Parser::consume(tokenKinds expectTok) { - if (token.is(expectTok)) { - advance(); - return true; - } - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_expected, - tokenNameMap[expectTok], token.getTokenName()); - return false; -} - -/// If current token's kind is expected kind, get next token. -/// If not, do nothing. -bool Parser::consumeNoAdvance(tokenKinds expectTok) { - if (token.is(expectTok)) - return true; - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_expected, - tokenNameMap[expectTok], token.getTokenName()); - return false; -} - -/// Parser the file, and return a Module, it store all information -/// to generate code. -Module *Parser::parser() { - Module *module = new Module(); - compilEngine(module); - return module; -} - -/// Parse keyword op, dialect and rule. -void Parser::compilEngine(Module *module) { - // rules store all rule ast. - std::vector rules; - // A file can only store one dialect. - Dialect *dialect = nullptr; - // ops store all op. - std::vector ops; - while (token.getKind() != tokenKinds::eof) { - if (token.is(tokenKinds::kw_rule)) { - advance(); - if (!consumeNoAdvance(tokenKinds::identifier)) - return; - Rule *rule = - new Rule(token.getContent(), token.getLocation(), AntlrBase::rule); - advance(); - parserRules(rule); - rules.push_back(rule); - consume(tokenKinds::semi); - } else if (token.is(tokenKinds::kw_dialect)) { - advance(); - if (!consumeNoAdvance(tokenKinds::identifier)) - return; - llvm::StringRef defName = token.getContent(); - advance(); - parserDialect(dialect, defName); - } else if (token.is(tokenKinds::kw_op)) { - advance(); - if (!parserOp(ops, token.getContent())) { - action.actOnModule(module, rules, dialect, ops); - return; - } - } else { - lexer.getDiagnostic().report( - token.getLocation(), DiagnosticEngine::err_expected, - "keyword rule, dialect or op", token.getTokenName()); - action.actOnModule(module, rules, dialect, ops); - return; - } - } - action.actOnModule(module, rules, dialect, ops); -} - -/// Parser the rule and fill nodes of rule ast. -void Parser::parserRules(Rule *rule) { - if (!consumeNoAdvance(tokenKinds::colon)) - return; - // A rule contains many generative. - std::vector generators; - while (token.getKind() != tokenKinds::semi && - token.getKind() == tokenKinds::colon) { - advance(); - GeneratorAndOthers *generatorAndOthers = new GeneratorAndOthers(); - parserGenerator(generatorAndOthers); - generators.push_back(generatorAndOthers); - if (!token.is(tokenKinds::colon) && !token.is(tokenKinds::semi)) { - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_expected, - "colon or semi", token.getTokenName()); - return; - } - } - - // Fill the rule ast. - action.actOnRule(rule, generators); -} - -/// Parser a generator and fill a node in generator. -void Parser::parserGenerator(GeneratorAndOthers *generatorAndOthers) { - while (token.is(tokenKinds::identifier) || token.is(tokenKinds::apostrophe) || - token.is(tokenKinds::plus) || token.is(tokenKinds::asterisk) || - token.is(tokenKinds::parentheseOpen) || - token.is(tokenKinds::parentheseClose) || - token.is(tokenKinds::questionMark) || - token.is(tokenKinds::curlyBlacketOpen)) { - if (token.is(tokenKinds::identifier)) - parserIdentifier(generatorAndOthers); - else if (token.is(tokenKinds::apostrophe)) - parserTerminator(generatorAndOthers); - else if (token.is(tokenKinds::curlyBlacketOpen)) - parserCurlyBracketOpen(generatorAndOthers); - else - parserPBExpression(generatorAndOthers); - } -} - -void Parser::parserCurlyBracketOpen(GeneratorAndOthers *generatorAndOthers) { - advance(); - llvm::SMLoc location = token.getLocation(); - if (token.getContent() == "builder") { - llvm::SmallVector builderNames; - llvm::SmallVector builderIdxs; - advance(); - if (!consume(tokenKinds::equal)) - return; - while (token.is(identifier)) { - int index; - if ((index = token.getContent().find('_')) == -1) - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_builder_fail); - llvm::StringRef builderOpName = token.getContent().substr(0, index); - std::string opBulderIdx = - token.getContent() - .substr(index + 1, token.getContent().size() - index) - .str(); - builderNames.push_back(builderOpName); - builderIdxs.push_back(std::stoi(opBulderIdx)); - advance(); - if (token.is(tokenKinds::comma)) - advance(); - } - generatorAndOthers->setbuilderNames(builderNames); - generatorAndOthers->setbuilderIdxs(builderIdxs); - } else { - lexer.getDiagnostic().report(location, - DiagnosticEngine::err_only_supported_builder); - return; - } - - consume(tokenKinds::curlyBlacketClose); -} - -/// Check if the identifier is a terminator. -AntlrBase::baseKind Parser::getAntlrBaseKind(llvm::StringRef name) { - if (terminators.isTerminator(name)) - return AntlrBase::baseKind::terminator; - return AntlrBase::baseKind::rule; -} - -/// processing the identifier, get the identifier's kind which stores -/// in the ast. -void Parser::parserIdentifier(GeneratorAndOthers *generatorAndOthers) { - AntlrBase::baseKind baseKind = getAntlrBaseKind(token.getContent()); - AntlrBase *r = nullptr; - if (baseKind == AntlrBase::baseKind::rule) - r = new Rule(token.getContent(), token.getLocation(), baseKind); - else if (baseKind == AntlrBase::AntlrBase::terminator) - r = new Terminator(token.getContent(), token.getLocation(), baseKind); - generatorAndOthers->getGenerator().push_back(r); - advance(); -} - -/// We support user-defined terminator.For example, we can write a 'terminator' -/// in a rule. -void Parser::parserTerminator(GeneratorAndOthers *generatorAndOthers) { - advance(); - AntlrBase *terminator = new Terminator( - token.getContent(), token.getLocation(), AntlrBase::terminator); - generatorAndOthers->getGenerator().push_back(terminator); - terminators.addCustomTerminators(token.getContent()); - advance(); - consume(tokenKinds::apostrophe); -} - -void Parser::parserPBExpression(GeneratorAndOthers *generatorAndOthers) { - AntlrBase *r = new Terminator(token.getContent(), token.getLocation(), - AntlrBase::pbexpression); - generatorAndOthers->getGenerator().push_back(r); - advance(); -} -/// Parser dialect keyword and fill all information in the dialect. -void Parser::parserDialect(Dialect *&dialect, llvm::StringRef defName) { - dialect = new Dialect(); - llvm::StringRef name; - llvm::StringRef cppNamespace; - while (token.is(tokenKinds::colon)) { - advance(); - if (token.getContent().str() == "name") { - advance(); - consumeNoAdvance(tokenKinds::equal); - name = lexer.getMarkContent("\"", "\""); - advance(); - } else if (token.getContent().str() == "cppNamespace") { - advance(); - consumeNoAdvance(tokenKinds::equal); - cppNamespace = lexer.getMarkContent("\"", "\""); - advance(); - } - } - action.actOnDialect(dialect, defName, name, cppNamespace); - advance(); -} - -/// Parser op keyword and fill all information in the ops. -bool Parser::parserOp(std::vector &ops, llvm::StringRef opName) { - DAG *arguments = nullptr; - DAG *results = nullptr; - std::vector builders; - advance(); - while (token.is(tokenKinds::colon)) { - advance(); - if (token.getContent() == "arguments") { - advance(); - if (!consumeNoAdvance(tokenKinds::equal)) - return false; - parserDAG(arguments); - advance(); - } else if (token.getContent() == "results") { - advance(); - if (!consumeNoAdvance(tokenKinds::equal)) - return false; - parserDAG(results); - advance(); - } else if (token.getContent() == "builders") { - advance(); - if (!consume(tokenKinds::equal)) - return false; - parserBuilders(builders); - advance(); - } else { - lexer.getDiagnostic().report(token.getLocation(), - DiagnosticEngine::err_not_supported_element, - token.getContent()); - return false; - } - } - if (!consume(tokenKinds::semi)) { - llvm::outs() << token.getContent(); - return false; - } - // Fill all information in the ops. - action.actOnOps(ops, opName, arguments, results, builders); - return true; -} - -/// parser DAG structure and fill all information in the arguments. -void Parser::parserDAG(DAG *&arguments) { - DAG dag; - advance(); - consume(tokenKinds::parentheseOpen); - llvm::StringRef dagOperator = token.getContent(); - advance(); - while (token.is(tokenKinds::identifier) || - token.is(tokenKinds::doubleQuotationMark)) { - int number = 0; - llvm::StringRef operandName; - llvm::StringRef operand; - llvm::StringRef value; - // If the operand provides a default value. - if (token.getContent() == "CArg") { - parserCArg(operand, value); - } else if (token.getContent() == "AnyTypeOf") { - const char *start = token.getContent().data(); - advance(); - if (!consumeNoAdvance(tokenKinds::angleBracketOpen)) - return; - operand = llvm::StringRef( - start, - 9 + lexer.getEndChContent(token.getContent().data(), '>').size()); - advance(); - } else if (token.is(tokenKinds::doubleQuotationMark)) { - // If the operand's type is cpp type. - operand = lexer.getEndChContent(token.getContent().data(), '"'); - advance(); - } else { - // If the operand's type is TableGen type. - operand = token.getContent(); - advance(); - if (token.is(tokenKinds::angleBracketOpen)) { - number++; - advance(); - if (token.is(tokenKinds::squareBracketOpen)) { - advance(); - number++; - } - llvm::StringRef type = token.getContent(); - advance(); - if (token.is(tokenKinds::squareBracketClose)) { - advance(); - number++; - } - consume(tokenKinds::angleBracketClose); - number++; - operand = llvm::StringRef(operand.data(), - operand.size() + number + type.size()); - } - } - // If operand is named. - if (token.is(tokenKinds::colon)) { - advance(); - advance(); - operandName = token.getContent(); - advance(); - } - dag.addOperand(operand, operandName); - if (!value.empty()) - dag.setValue(operand, value); - if (token.is(tokenKinds::comma)) - advance(); - } - dag.setDagOperatpr(dagOperator); - consumeNoAdvance(tokenKinds::parentheseClose); - // fill all information in the arguments. - action.actOnDag(arguments, dag); -} - -/// Parser opBuilder in the op. -void Parser::parserBuilders(std::vector &builders) { - if (!consume(tokenKinds::squareBracketOpen)) - return; - while (token.getContent() == "OpBuilder") { - DAG *dag = nullptr; - llvm::StringRef code; - advance(); - if (!consumeNoAdvance(tokenKinds::angleBracketOpen)) - return; - // Parser DAG. - parserDAG(dag); - advance(); - if (token.is(tokenKinds::comma)) { - // Parser code. - parserCode(code); - advance(); - } - if (!consume(tokenKinds::angleBracketClose)) - return; - Builder *builder = new Builder(dag, code); - builders.push_back(builder); - if (token.is(tokenKinds::comma)) - advance(); - } - consumeNoAdvance(tokenKinds::squareBracketClose); -} - -void Parser::parserCode(llvm::StringRef &code) { - code = lexer.getMarkContent("[", "]"); -} - -void Parser::parserCArg(llvm::StringRef &operand, llvm::StringRef &value) { - advance(); - consumeNoAdvance(tokenKinds::angleBracketOpen); - operand = lexer.getMarkContent("\"", "\""); - advance(); - value = lexer.getMarkContent("\"", "\""); - advance(); - consume(tokenKinds::angleBracketClose); -} diff --git a/frontend/FrontendGen/lib/Scope.cpp b/frontend/FrontendGen/lib/Scope.cpp new file mode 100644 index 000000000..9294cb889 --- /dev/null +++ b/frontend/FrontendGen/lib/Scope.cpp @@ -0,0 +1,85 @@ +#include "Scope.h" + +// SymbolTable +template void fegen::SymbolTable::add(std::string name, T *e) { + this->table.insert({name, e}); +} + +template T *fegen::SymbolTable::get(std::string name) { + return this->table[name]; +} + +template bool fegen::SymbolTable::exist(std::string name) { + return (this->table.count(name) > 0); +} + +template fegen::SymbolTable::~SymbolTable() { + for (auto pair : this->table) { + delete pair.second; + } +} + +// FegenScope +fegen::FegenScope::FegenScope(unsigned int scopeId, + fegen::FegenScope *parentScope) + : scopeId(scopeId), parentScope(parentScope) {} + +fegen::Value *fegen::FegenScope::findVar(std::string name) { + return this->varTable.get(name); +} + +void fegen::FegenScope::addVar(fegen::Value *var) { + this->varTable.add(var->getName(), var); +} + +bool fegen::FegenScope::isExistVar(std::string name) { + return this->varTable.exist(name); +} + +fegen::ScopeStack::ScopeStack() : count(1) { + this->globalScope = new fegen::FegenScope(0, nullptr); + this->currentScope = this->globalScope; + this->scopeStack.push(this->globalScope); + this->scopes.push_back(this->globalScope); +} + +fegen::ScopeStack::~ScopeStack() { + for (auto scope : this->scopes) { + delete scope; + } +} + +fegen::ScopeStack &fegen::ScopeStack::getScopeStack() { + static fegen::ScopeStack sstack; + return sstack; +} + +void fegen::ScopeStack::pushScope() { + auto newScope = new fegen::FegenScope(this->count++, this->currentScope); + this->scopeStack.push(newScope); + this->scopes.push_back(newScope); + this->currentScope = newScope; +} + +void fegen::ScopeStack::popScope() { + this->scopeStack.pop(); + this->currentScope = this->scopeStack.top(); +} +bool fegen::ScopeStack::attemptAddVar(fegen::Value *var) { + if (this->currentScope->isExistVar(var->getName())) { + return false; + } + this->currentScope->addVar(var); + return true; +} + +fegen::Value *fegen::ScopeStack::attemptFindVar(std::string name) { + auto p = this->currentScope; + while (p != nullptr) { + if (p->isExistVar(name)) { + return p->findVar(name); + } + p = p->parentScope; + } + return nullptr; +} \ No newline at end of file diff --git a/frontend/FrontendGen/lib/Sema.cpp b/frontend/FrontendGen/lib/Sema.cpp deleted file mode 100644 index 00cc005e2..000000000 --- a/frontend/FrontendGen/lib/Sema.cpp +++ /dev/null @@ -1,54 +0,0 @@ -//====- Sema.cpp ---------------------------------------------------------===// -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "Sema.h" -#include "llvm/Support/raw_ostream.h" -using namespace frontendgen; - -/// Set Module's nodes. -void Sema::actOnModule(Module *module, std::vector &rules, - Dialect *&dialect, std::vector &ops) { - module->setRules(rules); - module->seDialect(dialect); - module->setOps(ops); -} -/// Set Rule's node. -void Sema::actOnRule(Rule *rule, - std::vector &generators) { - rule->setGenerators(generators); -} - -/// Set Dialect's nodes. -void Sema::actOnDialect(Dialect *dialect, llvm::StringRef defName, - llvm::StringRef name, llvm::StringRef cppNamespace) { - dialect->setDefName(defName); - dialect->setName(name); - dialect->setCppNamespace(cppNamespace); -} - -/// Make a op and make it in the ops. -void Sema::actOnOps(std::vector &ops, llvm::StringRef opName, - DAG *arguments, DAG *results, - std::vector &builders) { - Op *op = new Op(); - op->setOpName(opName); - op->setArguments(arguments); - op->setResults(results); - op->setBuilders(builders); - ops.push_back(op); -} - -void Sema::actOnDag(DAG *&arguments, DAG &dag) { arguments = new DAG(dag); }