Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/circt/Dialect/Arc/ArcOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,16 @@ def SimInstantiateOp : ArcOp<"sim.instantiate",
}];
let regions = (region SizedRegion<1>:$body);

let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$runtimeModel,
OptionalAttr<StrAttr>:$runtimeArgs);

let hasRegionVerifier = 1;
let hasCustomAssemblyFormat = 1;

let builders = [OpBuilder<(ins),
[{
build($_builder, $_state, FlatSymbolRefAttr{}, StringAttr{});
}]>];
}

def SimSetInputOp : ArcOp<"sim.set_input",
Expand Down Expand Up @@ -949,4 +957,17 @@ def VectorizeReturnOp : ArcOp<"vectorize.return", [
let assemblyFormat = "operands attr-dict `:` qualified(type(operands))";
}

def RuntimeModelOp : ArcOp<"runtime.model", [
HasParent<"::mlir::ModuleOp">, Pure, Symbol
]> {
let summary = "TODO";
let arguments = (ins SymbolNameAttr:$sym_name,
StrAttr:$name,
I64Attr:$numStateBytes);

let assemblyFormat = [{
$sym_name $name `numStateBytes` $numStateBytes attr-dict
}];
}

#endif // CIRCT_DIALECT_ARC_ARCOPS_TD
9 changes: 9 additions & 0 deletions include/circt/Dialect/Arc/ArcPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,15 @@ def InferStateProperties : Pass<"arc-infer-state-properties",
];
}

def InsertRuntime : Pass<"arc-insert-runtime", "mlir::ModuleOp"> {
let summary = "TODO";
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let options = [
Option<"extraArgs", "extra-args", "std::string", "",
"Extra arguments passed to the runtime when creating simulation instances">
];
}

def IsolateClocks : Pass<"arc-isolate-clocks", "mlir::ModuleOp"> {
let summary = "Group clocked operations into clock domains";
let constructor = "circt::arc::createIsolateClocksPass()";
Expand Down
6 changes: 6 additions & 0 deletions include/circt/Dialect/Arc/ModelInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef CIRCT_DIALECT_ARC_MODELINFO_H
#define CIRCT_DIALECT_ARC_MODELINFO_H

#include "circt/Dialect/Arc/ArcOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -48,6 +49,11 @@ struct ModelInfo {
finalFnSym(finalFnSym) {}
};

struct ModelInfoAnalysis {
explicit ModelInfoAnalysis(Operation *container);
llvm::DenseMap<ModelOp, ModelInfo> infoMap;
};

/// Collects information about states within the provided Arc model storage
/// `storage`, assuming default `offset`, and adds it to `states`.
mlir::LogicalResult collectStates(mlir::Value storage, unsigned offset,
Expand Down
192 changes: 161 additions & 31 deletions lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"

#include "circt/Tools/arcilator/ArcRuntime/Common.h"
#include "circt/Tools/arcilator/ArcRuntime/JITBind.h"

#include <cstdlib>

#define DEBUG_TYPE "lower-arc-to-llvm"

namespace circt {
Expand Down Expand Up @@ -373,6 +378,8 @@ struct SimInstantiateOpLowering
.getValue());
ModelInfoMap &model = modelIt->second;

bool useRuntime = op.getRuntimeModel().has_value();

ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
Expand All @@ -382,27 +389,63 @@ struct SimInstantiateOpLowering
// FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
// sizeof(size_t) on the target architecture.
Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());

FailureOr<LLVM::LLVMFuncOp> mallocFunc =
LLVM::lookupOrCreateMallocFn(rewriter, moduleOp, convertedIndex);
if (failed(mallocFunc))
return mallocFunc;

FailureOr<LLVM::LLVMFuncOp> freeFunc =
LLVM::lookupOrCreateFreeFn(rewriter, moduleOp);
if (failed(freeFunc))
return freeFunc;

Location loc = op.getLoc();
Value numStateBytes = LLVM::ConstantOp::create(
rewriter, loc, convertedIndex, model.numStateBytes);
Value allocated = LLVM::CallOp::create(rewriter, loc, mallocFunc.value(),
ValueRange{numStateBytes})
.getResult();
Value zero =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI8Type(), 0);
LLVM::MemsetOp::create(rewriter, loc, allocated, zero, numStateBytes,
false);
Value allocated;

if (useRuntime) {
auto ptrTy = LLVM::LLVMPointerType::get(getContext());

Value runtimeArgs;
if (op.getRuntimeArgs().has_value()) {
SmallVector<int8_t> argStringVec(op.getRuntimeArgsAttr().begin(),
op.getRuntimeArgsAttr().end());
argStringVec.push_back('\0');
auto strAttr = mlir::DenseElementsAttr::get(
mlir::RankedTensorType::get({(int64_t)argStringVec.size()},
rewriter.getI8Type()),
llvm::ArrayRef(argStringVec));

auto arrayCst = LLVM::ConstantOp::create(
rewriter, loc,
LLVM::LLVMArrayType::get(rewriter.getI8Type(), argStringVec.size()),
strAttr);
auto cst1 = LLVM::ConstantOp::create(rewriter, loc,
rewriter.getI32IntegerAttr(1));
runtimeArgs = LLVM::AllocaOp::create(rewriter, loc, ptrTy,
arrayCst.getType(), cst1);
LLVM::LifetimeStartOp::create(rewriter, loc, runtimeArgs);
LLVM::StoreOp::create(rewriter, loc, arrayCst, runtimeArgs);
} else {
runtimeArgs = LLVM::ZeroOp::create(rewriter, loc, ptrTy).getResult();
}
auto rtModelPtr = LLVM::AddressOfOp::create(rewriter, loc, ptrTy,
op.getRuntimeModelAttr())
.getResult();
allocated =
LLVM::CallOp::create(rewriter, loc, {ptrTy},
runtime::APICallbacks::symNameAllocInstance,
{rtModelPtr, runtimeArgs})
.getResult();

if (op.getRuntimeArgs().has_value())
LLVM::LifetimeEndOp::create(rewriter, loc, runtimeArgs);

} else {
FailureOr<LLVM::LLVMFuncOp> mallocFunc =
LLVM::lookupOrCreateMallocFn(rewriter, moduleOp, convertedIndex);
if (failed(mallocFunc))
return mallocFunc;

Value numStateBytes = LLVM::ConstantOp::create(
rewriter, loc, convertedIndex, model.numStateBytes);
allocated = LLVM::CallOp::create(rewriter, loc, mallocFunc.value(),
ValueRange{numStateBytes})
.getResult();
Value zero =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI8Type(), 0);
LLVM::MemsetOp::create(rewriter, loc, allocated, zero, numStateBytes,
false);
}

// Call the model's 'initial' function if present.
if (model.initialFnSymbol) {
Expand All @@ -426,10 +469,21 @@ struct SimInstantiateOpLowering
ValueRange{allocated});
}

LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
ValueRange{allocated});
rewriter.eraseOp(op);
if (useRuntime) {
LLVM::CallOp::create(rewriter, loc, TypeRange{},
runtime::APICallbacks::symNameDeleteInstance,
{allocated});
} else {
FailureOr<LLVM::LLVMFuncOp> freeFunc =
LLVM::lookupOrCreateFreeFn(rewriter, moduleOp);
if (failed(freeFunc))
return freeFunc;

LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
ValueRange{allocated});
}

rewriter.eraseOp(op);
return success();
}
};
Expand Down Expand Up @@ -679,6 +733,86 @@ static LogicalResult convert(arc::ExecuteOp op, arc::ExecuteOp::Adaptor adaptor,
return success();
}

//===----------------------------------------------------------------------===//
// Runtime Implementation
//===----------------------------------------------------------------------===//

struct RuntimeModelOpLowering
: public OpConversionPattern<arc::RuntimeModelOp> {
using OpConversionPattern::OpConversionPattern;

static constexpr uint64_t runtimeApiVersion = ARC_RUNTIME_API_VERSION;

LogicalResult
matchAndRewrite(arc::RuntimeModelOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {

auto modelInfoStructType = LLVM::LLVMStructType::getLiteral(
getContext(), {rewriter.getI64Type(), rewriter.getI64Type(),
LLVM::LLVMPointerType::get(getContext())});
static_assert(sizeof(ArcRuntimeModelInfo) == 24 &&
"Unexpected size of ArcRuntimeModelInfo struct");

// Construct the Model Name String
rewriter.setInsertionPoint(op);
SmallVector<char, 16> modNameArray(op.getName().begin(),
op.getName().end());
modNameArray.push_back('\0');
auto nameGlobalType =
LLVM::LLVMArrayType::get(rewriter.getI8Type(), modNameArray.size());
SmallString<16> globalSymName{"_arc_mod_name_"};
globalSymName.append(op.getName());
auto nameGlobal = LLVM::GlobalOp::create(
rewriter, op.getLoc(), nameGlobalType, /*isConstant=*/true,
LLVM::Linkage::Internal,
/*name=*/globalSymName, rewriter.getStringAttr(modNameArray),
/*alignment=*/0);

// Construct the Model Info Struct

auto modInfoGlobalOp =
LLVM::GlobalOp::create(rewriter, op.getLoc(), modelInfoStructType,
/*isConstant=*/false, LLVM::Linkage::External,
op.getSymName(), Attribute{});

Region &initRegion = modInfoGlobalOp.getInitializerRegion();
Block *initBlock = rewriter.createBlock(&initRegion);
rewriter.setInsertionPointToStart(initBlock);

auto apiVersionCst = LLVM::ConstantOp::create(
rewriter, op.getLoc(), rewriter.getI64IntegerAttr(runtimeApiVersion));
auto numStateBytesCst = LLVM::ConstantOp::create(rewriter, op.getLoc(),
op.getNumStateBytesAttr());
auto nameAddr =
LLVM::AddressOfOp::create(rewriter, op.getLoc(), nameGlobal);

Value initStruct =
LLVM::PoisonOp::create(rewriter, op.getLoc(), modelInfoStructType);

// Field: uint64_t apiVersion
initStruct = LLVM::InsertValueOp::create(
rewriter, op.getLoc(), initStruct, apiVersionCst, ArrayRef<int64_t>{0});
static_assert(offsetof(ArcRuntimeModelInfo, apiVersion) == 0,
"Unexpected offset of field apiVersion");
// Field: uint64_t numStateBytes
initStruct =
LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
numStateBytesCst, ArrayRef<int64_t>{1});
static_assert(offsetof(ArcRuntimeModelInfo, numStateBytes) == 8,
"Unexpected offset of field numStateBytes");
// Field: const char *modelName
initStruct = LLVM::InsertValueOp::create(rewriter, op.getLoc(), initStruct,
nameAddr, ArrayRef<int64_t>{2});
static_assert(offsetof(ArcRuntimeModelInfo, modelName) == 16,
"Unexpected offset of field modelName");

LLVM::ReturnOp::create(rewriter, op.getLoc(), initStruct);

rewriter.replaceOp(op, modInfoGlobalOp);
return success();
}
};

//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -788,6 +922,7 @@ void LowerArcToLLVMPass::runOnOperation() {
ModelOpLowering,
ReplaceOpWithInputPattern<seq::ToClockOp>,
ReplaceOpWithInputPattern<seq::FromClockOp>,
RuntimeModelOpLowering,
SeqConstClockLowering,
SimEmitValueOpLowering,
StateReadOpLowering,
Expand All @@ -798,14 +933,9 @@ void LowerArcToLLVMPass::runOnOperation() {
// clang-format on
patterns.add<ExecuteOp>(convert);

SmallVector<ModelInfo> models;
if (failed(collectModels(getOperation(), models))) {
signalPassFailure();
return;
}

llvm::DenseMap<StringRef, ModelInfoMap> modelMap(models.size());
for (ModelInfo &modelInfo : models) {
auto &modelInfo = getAnalysis<ModelInfoAnalysis>();
llvm::DenseMap<StringRef, ModelInfoMap> modelMap(modelInfo.infoMap.size());
for (auto &[_, modelInfo] : modelInfo.infoMap) {
llvm::DenseMap<StringRef, StateInfo> states(modelInfo.states.size());
for (StateInfo &stateInfo : modelInfo.states)
states.insert({stateInfo.name, stateInfo});
Expand Down
49 changes: 46 additions & 3 deletions lib/Dialect/Arc/ArcOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,19 @@ void SimInstantiateOp::print(OpAsmPrinter &p) {
p << " " << modelType.getModel() << " as ";
p.printRegionArgument(modelArg, {}, true);

p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
if (getRuntimeModel() || getRuntimeArgs()) {
p << " runtime ";
if (getRuntimeModel())
p << getRuntimeModelAttr();
p << "(";
if (getRuntimeArgs())
p << getRuntimeArgsAttr();
p << ")";
}

p.printOptionalAttrDictWithKeyword(
getOperation()->getAttrs(),
{getRuntimeModelAttrName(), getRuntimeArgsAttrName()});

p << " ";

Expand All @@ -520,6 +532,24 @@ ParseResult SimInstantiateOp::parse(OpAsmParser &parser,
if (failed(parser.parseArgument(modelArg, false, false)))
return failure();

if (succeeded(parser.parseOptionalKeyword("runtime"))) {
StringAttr runtimeSym;
StringAttr runtimeArgs;
auto symOpt = parser.parseOptionalSymbolName(runtimeSym);
if (parser.parseLParen())
return failure();
auto nameOpt = parser.parseOptionalAttribute(runtimeArgs);
if (parser.parseRParen())
return failure();
if (succeeded(symOpt))
result.addAttribute(
SimInstantiateOp::getRuntimeModelAttrName(result.name),
FlatSymbolRefAttr::get(runtimeSym));
if (nameOpt.has_value())
result.addAttribute(SimInstantiateOp::getRuntimeArgsAttrName(result.name),
runtimeArgs);
}

if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();

Expand Down Expand Up @@ -547,15 +577,28 @@ LogicalResult SimInstantiateOp::verifyRegions() {

LogicalResult
SimInstantiateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
bool failed = false;
Operation *moduleOp = getSupportedModuleOp(
symbolTable, getOperation(),
llvm::cast<SimModelInstanceType>(getBody().getArgument(0).getType())
.getModel()
.getAttr());
if (!moduleOp)
return failure();
failed = true;

if (getRuntimeModel().has_value()) {
Operation *runtimeModelOp = symbolTable.lookupNearestSymbolFrom(
getOperation(), getRuntimeModelAttr());
if (!runtimeModelOp) {
emitOpError("runtime model not found");
failed = true;
} else if (!isa<RuntimeModelOp>(runtimeModelOp)) {
emitOpError("referenced runtime model is not a RuntimeModelOp");
failed = true;
}
}

return success();
return success(!failed);
}

//===----------------------------------------------------------------------===//
Expand Down
Loading