10#include "mlir/IR/Builders.h"
11#include "mlir/IR/Dominance.h"
17using namespace mlir::abi;
44bool needsRewrite(
const FunctionClassification &fc) {
49 if ((fc.returnInfo.kind != ArgKind::Direct) || fc.returnInfo.coercedType)
51 for (
const ArgClassification &ac : fc.argInfos)
52 if ((ac.kind != ArgKind::Direct) || ac.coercedType)
69 if (ac.kind != ArgKind::Direct || !ac.coercedType || !ac.canFlatten)
71 auto recTy = dyn_cast<cir::RecordType>(ac.coercedType);
72 if (!recTy || !recTy.isStruct() || recTy.getNumElements() <= 1)
84 const FunctionClassification &fc,
85 SmallVectorImpl<mlir::Type> &newArgTypes,
86 function_ref<mlir::InFlightDiagnostic()> emitError) {
87 assert(newArgTypes.empty() &&
"expected an empty output vector");
88 newArgTypes.reserve(oldArgTypes.size());
89 for (
auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
90 mlir::Type origTy = oldArgTypes[idx];
98 llvm::append_range(newArgTypes, flatTy.getMembers());
104 newArgTypes.push_back(ac.coercedType ? ac.coercedType : origTy);
107 case ArgKind::Ignore:
109 case ArgKind::Expand: {
113 auto recTy = cast<cir::RecordType>(origTy);
114 assert(recTy.isStruct() &&
115 "Expand classification requires a struct type, not a union");
116 assert(!recTy.getMembers().empty() &&
117 "Expand classification requires at least one struct field");
118 llvm::append_range(newArgTypes, recTy.getMembers());
121 case ArgKind::Extend:
128 newArgTypes.push_back(origTy);
130 case ArgKind::Indirect:
136 newArgTypes.push_back(cir::PointerType::get(origTy));
140 return mlir::success();
148computeNewReturnType(mlir::Type origRetTy,
const ArgClassification &retInfo,
149 mlir::MLIRContext *ctx,
150 function_ref<mlir::InFlightDiagnostic()> emitError) {
151 switch (retInfo.kind) {
152 case ArgKind::Direct:
155 return retInfo.coercedType ? retInfo.coercedType : origRetTy;
156 case ArgKind::Ignore:
157 return cir::VoidType::get(ctx);
158 case ArgKind::Expand:
159 emitError() <<
"Expand return is not allowed (classic codegen rejects "
160 <<
"it in EmitFunctionEpilog)";
162 case ArgKind::Extend:
167 case ArgKind::Indirect:
172 return cir::VoidType::get(ctx);
174 llvm_unreachable(
"all ArgKind cases handled");
182mlir::Value createIgnoredValue(mlir::OpBuilder &builder, mlir::Location loc,
184 return cir::ConstantOp::create(builder, loc, ty, cir::PoisonAttr::get(ty));
192mlir::ArrayAttr updateArgAttrs(mlir::MLIRContext *ctx,
194 mlir::ArrayAttr existingArgAttrs,
195 const FunctionClassification &fc) {
196 mlir::Builder builder(ctx);
198 newArgAttrs.reserve(fc.argInfos.size());
199 for (
auto [oldIdx, ac] : llvm::enumerate(fc.argInfos)) {
200 if (ac.kind == ArgKind::Ignore)
202 mlir::DictionaryAttr existing = builder.getDictionaryAttr({});
203 if (existingArgAttrs && oldIdx < existingArgAttrs.size())
204 existing = mlir::cast<mlir::DictionaryAttr>(existingArgAttrs[oldIdx]);
208 newArgAttrs.append(flatTy.getNumElements(),
209 builder.getDictionaryAttr({}));
210 }
else if (ac.kind == ArgKind::Expand) {
213 auto recTy = cast<cir::RecordType>(origArgTypes[oldIdx]);
214 newArgAttrs.append(recTy.getNumElements(), builder.getDictionaryAttr({}));
215 }
else if (ac.kind == ArgKind::Extend) {
216 StringRef attrName = ac.signExtend ?
"llvm.signext" :
"llvm.zeroext";
218 attrs.push_back(builder.getNamedAttr(attrName, builder.getUnitAttr()));
219 newArgAttrs.push_back(builder.getDictionaryAttr(attrs));
220 }
else if (ac.kind == ArgKind::Indirect) {
237 mlir::Type pointeeTy = origArgTypes[oldIdx];
238 StringRef ownershipAttr = ac.byVal ?
"llvm.byval" :
"llvm.byref";
240 attrs.push_back(builder.getNamedAttr(
241 "llvm.align", builder.getI64IntegerAttr(ac.indirectAlign.value())));
243 builder.getNamedAttr(ownershipAttr, mlir::TypeAttr::get(pointeeTy)));
246 builder.getNamedAttr(
"llvm.noalias", builder.getUnitAttr()));
248 builder.getNamedAttr(
"llvm.noundef", builder.getUnitAttr()));
250 newArgAttrs.push_back(builder.getDictionaryAttr(attrs));
252 newArgAttrs.push_back(existing);
255 return builder.getArrayAttr(newArgAttrs);
261mlir::ArrayAttr updateResAttrs(mlir::MLIRContext *ctx,
262 mlir::ArrayAttr existingResAttrs,
263 const ArgClassification &retInfo) {
264 if (retInfo.kind != ArgKind::Extend)
265 return existingResAttrs;
268 if (existingResAttrs && !existingResAttrs.empty())
269 for (mlir::NamedAttribute na :
270 mlir::cast<mlir::DictionaryAttr>(existingResAttrs[0]))
272 StringRef attrName = retInfo.signExtend ?
"llvm.signext" :
"llvm.zeroext";
273 attrs.push_back(mlir::NamedAttribute(mlir::StringAttr::get(ctx, attrName),
274 mlir::UnitAttr::get(ctx)));
275 return mlir::ArrayAttr::get(ctx, {mlir::DictionaryAttr::get(ctx, attrs)});
303emitCoercionToMemory(mlir::OpBuilder &builder, mlir::Location loc,
304 mlir::Type dstTy, mlir::Value src,
305 mlir::FunctionOpInterface funcOp,
306 const mlir::DataLayout &dl,
307 SmallPtrSetImpl<mlir::Operation *> &createdOps) {
308 mlir::Type srcTy = src.getType();
309 assert(srcTy != dstTy &&
310 "emitCoercion callers must pre-check that the types differ");
312 uint64_t srcAlign = dl.getTypeABIAlignment(srcTy);
313 uint64_t dstAlign = dl.getTypeABIAlignment(dstTy);
314 uint64_t allocaAlign = std::max(srcAlign, dstAlign);
316 dl.getTypeSize(srcTy) >= dl.getTypeSize(dstTy) ? srcTy : dstTy;
318 auto slotPtrTy = cir::PointerType::get(slotTy);
319 auto srcPtrTy = cir::PointerType::get(srcTy);
320 auto dstPtrTy = cir::PointerType::get(dstTy);
322 cir::AllocaOp alloca;
324 mlir::OpBuilder::InsertionGuard guard(builder);
325 mlir::Block &entry = funcOp->getRegion(0).front();
326 builder.setInsertionPointToStart(&entry);
327 alloca = cir::AllocaOp::create(builder, loc, slotPtrTy,
328 builder.getStringAttr(
"coerce"),
329 builder.getI64IntegerAttr(allocaAlign));
331 createdOps.insert(alloca);
334 mlir::Value srcSlot = alloca;
335 if (slotTy != srcTy) {
336 auto srcCast = cir::CastOp::create(builder, loc, srcPtrTy,
337 cir::CastKind::bitcast, alloca);
338 createdOps.insert(srcCast);
341 auto store = cir::StoreOp::create(builder, loc, src, srcSlot);
342 createdOps.insert(store);
345 if (slotTy != dstTy) {
346 auto dstCast = cir::CastOp::create(builder, loc, dstPtrTy,
347 cir::CastKind::bitcast, alloca);
348 createdOps.insert(dstCast);
357mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
358 mlir::Type dstTy, mlir::Value src,
359 mlir::FunctionOpInterface funcOp,
360 const mlir::DataLayout &dl,
361 SmallPtrSetImpl<mlir::Operation *> &createdOps) {
362 mlir::Value dstSlot =
363 emitCoercionToMemory(builder, loc, dstTy, src, funcOp, dl, createdOps);
364 auto load = cir::LoadOp::create(builder, loc, dstSlot);
365 createdOps.insert(load);
371mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
372 mlir::Type dstTy, mlir::Value src,
373 mlir::FunctionOpInterface funcOp,
374 const mlir::DataLayout &dl) {
375 SmallPtrSet<mlir::Operation *, 4> ignored;
376 return emitCoercion(builder, loc, dstTy, src, funcOp, dl, ignored);
381void insertReturnCoercion(mlir::FunctionOpInterface funcOp,
382 mlir::Type origRetTy, mlir::Type coercedRetTy,
383 mlir::OpBuilder &builder,
384 const mlir::DataLayout &dl) {
386 funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
387 for (cir::ReturnOp r : returns) {
388 if (r.getInput().empty())
390 mlir::Value origVal = r.getInput()[0];
391 if (origVal.getType() == coercedRetTy)
393 builder.setInsertionPoint(r);
394 mlir::Value coerced =
395 emitCoercion(builder, r.getLoc(), coercedRetTy, origVal, funcOp, dl);
396 r->setOperand(0, coerced);
412emitStructFieldArgs(mlir::OpBuilder &builder, mlir::Location loc,
414 SmallVectorImpl<mlir::Value> &newArgs,
415 SmallVectorImpl<cir::LoadOp> &replacedWholeLoads) {
416 cir::LoadOp wholeLoad = structVal.getDefiningOp<cir::LoadOp>();
417 cir::AllocaOp srcAlloca;
418 if (wholeLoad && !wholeLoad.getIsVolatile() && !wholeLoad.getMemOrder())
419 srcAlloca = wholeLoad.getAddr().getDefiningOp<cir::AllocaOp>();
422 mlir::OpBuilder::InsertionGuard guard(builder);
423 builder.setInsertionPoint(wholeLoad);
424 for (
auto [f, fieldTy] : llvm::enumerate(recTy.
getMembers())) {
425 mlir::Type fieldPtrTy = cir::PointerType::get(fieldTy);
426 mlir::Value fieldPtr = cir::GetMemberOp::create(
427 builder, loc, fieldPtrTy, srcAlloca,
"", f);
428 newArgs.push_back(cir::LoadOp::create(builder, loc, fieldPtr));
430 replacedWholeLoads.push_back(wholeLoad);
434 cir::ExtractMemberOp::create(builder, loc, structVal, f));
452void insertArgCoercion(mlir::FunctionOpInterface funcOp,
453 const FunctionClassification &fc,
454 mlir::OpBuilder &builder,
const mlir::DataLayout &dl,
456 mlir::Region &body = funcOp->getRegion(0);
459 mlir::Block &entry = body.front();
465 unsigned blockArgIdx = hasSRetArg ? 1 : 0;
467 for (
const ArgClassification &ac : fc.argInfos) {
468 assert(blockArgIdx < entry.getNumArguments() &&
469 "classification count must not exceed entry block arguments");
471 if (ac.kind == ArgKind::Expand) {
475 mlir::BlockArgument origArg = entry.getArgument(blockArgIdx);
476 auto recTy = cast<cir::RecordType>(origArg.getType());
478 "Expand classification requires a struct type, not a union");
480 assert(numFields > 0 &&
481 "Expand classification requires at least one struct field");
482 mlir::Location loc = funcOp.getLoc();
492 cir::StoreOp paramStore;
493 cir::AllocaOp destAlloca;
494 if (!origArg.use_empty()) {
495 assert(origArg.hasOneUse() &&
496 "Expand arg must have exactly one use (the CIRGen param spill)");
497 paramStore = cast<cir::StoreOp>(*origArg.user_begin());
498 assert(paramStore.getValue() == origArg &&
499 "Expand arg's use must be the value operand of its store");
500 destAlloca = cast<cir::AllocaOp>(paramStore.getAddr().getDefiningOp());
507 mlir::Operation *fieldStoreInsertPt =
nullptr;
509 fieldStoreInsertPt = paramStore->getNextNode();
510 assert(fieldStoreInsertPt &&
511 "param spill must be followed by a block terminator");
523 builder.setInsertionPoint(fieldStoreInsertPt);
524 for (
auto [f, fieldTy] : llvm::enumerate(recTy.
getMembers())) {
526 origArg.setType(fieldTy);
528 entry.insertArgument(blockArgIdx + f, fieldTy, loc);
531 mlir::Type fieldPtrTy = cir::PointerType::get(fieldTy);
532 auto fieldPtr = cir::GetMemberOp::create(builder, loc, fieldPtrTy,
535 cir::StoreOp::create(builder, loc, entry.getArgument(blockArgIdx + f),
539 blockArgIdx += numFields;
543 mlir::BlockArgument blockArg = entry.getArgument(blockArgIdx);
552 unsigned numFields = flatTy.getNumElements();
553 assert(numFields >= 2 &&
"getFlattenedCoercedType guarantees >1 fields");
554 Type origTy = blockArg.getType();
555 Location loc = funcOp.getLoc();
558 blockArg.setType(flatTy.getElementType(0));
559 for (
unsigned f = 1; f < numFields; ++f)
560 entry.insertArgument(blockArgIdx + f, flatTy.getElementType(f), loc);
563 builder.setInsertionPointToStart(&entry);
564 auto flatPtrTy = cir::PointerType::get(flatTy);
565 uint64_t flatAlign = dl.getTypeABIAlignment(flatTy);
566 auto flatSlot = cir::AllocaOp::create(
567 builder, loc, flatPtrTy, builder.getStringAttr(
"coerce"),
568 builder.getI64IntegerAttr(flatAlign));
569 SmallPtrSet<Operation *, 8> flattenOps = {flatSlot};
570 for (
auto [f, fieldTy] : llvm::enumerate(flatTy.getMembers())) {
571 Type fieldPtrTy = cir::PointerType::get(fieldTy);
572 auto fieldPtr = cir::GetMemberOp::create(builder, loc, fieldPtrTy,
575 flattenOps.insert(fieldPtr);
576 auto storeOp = cir::StoreOp::create(
577 builder, loc, entry.getArgument(blockArgIdx + f), fieldPtr);
578 flattenOps.insert(storeOp);
581 cir::LoadOp::create(builder, loc, flatTy, flatSlot.getResult());
582 flattenOps.insert(flatLoaded);
586 Value finalVal = flatLoaded;
587 if (origTy != flatTy) {
588 SmallPtrSet<Operation *, 4> coercionOps;
589 finalVal = emitCoercion(builder, loc, origTy, flatLoaded, funcOp, dl,
591 flattenOps.insert(coercionOps.begin(), coercionOps.end());
596 blockArg.replaceAllUsesExcept(finalVal, flattenOps);
598 blockArgIdx += numFields;
602 if (ac.kind == ArgKind::Direct && ac.coercedType) {
603 mlir::Type oldArgTy = blockArg.getType();
604 mlir::Type newArgTy = ac.coercedType;
605 if (oldArgTy == newArgTy) {
609 blockArg.setType(newArgTy);
611 builder.setInsertionPointToStart(&entry);
612 SmallPtrSet<mlir::Operation *, 4> coercionOps;
613 mlir::Value adapted = emitCoercion(builder, funcOp.getLoc(), oldArgTy,
614 blockArg, funcOp, dl, coercionOps);
620 blockArg.replaceAllUsesExcept(adapted, coercionOps);
621 }
else if (ac.kind == ArgKind::Indirect) {
627 mlir::Type origTy = blockArg.getType();
628 auto ptrTy = cir::PointerType::get(origTy);
629 blockArg.setType(ptrTy);
631 builder.setInsertionPointToStart(&entry);
632 auto loadOp = cir::LoadOp::create(builder, funcOp.getLoc(), blockArg);
633 SmallPtrSet<mlir::Operation *, 1> loadOps = {loadOp};
634 blockArg.replaceAllUsesExcept(loadOp.getResult(), loadOps);
670void insertSRetStores(mlir::FunctionOpInterface funcOp, mlir::Type origRetTy,
671 mlir::OpBuilder &builder) {
672 mlir::Value sretPtr = funcOp.getArguments()[0];
675 funcOp->walk([&](cir::ReturnOp retOp) { returnOps.push_back(retOp); });
677 cir::AllocaOp retAlloca =
nullptr;
678 for (cir::ReturnOp retOp : returnOps) {
681 assert(!retOp.getInput().empty() &&
682 "cir.return in sret function must have an operand");
684 cir::LoadOp retLoad =
685 mlir::cast<cir::LoadOp>(retOp.getInput()[0].getDefiningOp());
693 retAlloca = mlir::cast<cir::AllocaOp>(retLoad.getAddr().getDefiningOp());
694 retAlloca.getResult().replaceAllUsesWith(sretPtr);
700 builder.setInsertionPoint(retOp);
701 cir::ReturnOp::create(builder, retOp.getLoc());
703 if (retLoad.use_empty())
722 builder.getNamedAttr(
"llvm.sret", mlir::TypeAttr::get(retTy)));
724 builder.getNamedAttr(
"llvm.align", builder.getI64IntegerAttr(align)));
727 builder.getNamedAttr(
"llvm.noalias", builder.getUnitAttr()));
728 attrs.push_back(builder.getNamedAttr(
"llvm.writable", builder.getUnitAttr()));
730 builder.getNamedAttr(
"llvm.dead_on_unwind", builder.getUnitAttr()));
740void applySretSlotAttrs(cir::CallOp newCall, mlir::ArrayAttr argAttrs,
741 mlir::Type retTy, uint64_t align,
742 mlir::OpBuilder &builder) {
743 mlir::MLIRContext *ctx = newCall->getContext();
745 buildSretSlotAttrs(builder, retTy, align,
false);
748 newArgAttrs.reserve(newCall.getArgOperands().size());
749 newArgAttrs.push_back(mlir::DictionaryAttr::get(ctx, sretAttrs));
751 llvm::append_range(newArgAttrs, argAttrs);
752 assert(newArgAttrs.size() <= newCall.getArgOperands().size() &&
753 "arg_attrs wider than the rewritten call's operand list");
754 newArgAttrs.resize(newCall.getArgOperands().size(),
755 mlir::DictionaryAttr::get(ctx));
756 newCall->setAttr(
"arg_attrs", mlir::ArrayAttr::get(ctx, newArgAttrs));
766void rewriteIndirectReturnCall(cir::CallOp call,
767 const FunctionClassification &fc,
769 mlir::Type origRetTy,
771 mlir::OpBuilder &builder) {
772 mlir::MLIRContext *ctx = call->getContext();
773 auto ptrTy = cir::PointerType::get(origRetTy);
774 builder.setInsertionPoint(call);
775 uint64_t sretAlign = fc.returnInfo.indirectAlign.value();
788 mlir::Value sretSlot =
nullptr;
789 cir::StoreOp reuseStore =
nullptr;
790 if (call.getResult().hasOneUse()) {
791 mlir::Operation *user = *call.getResult().getUsers().begin();
792 if (
auto store = mlir::dyn_cast<cir::StoreOp>(user))
793 if (store.getValue() == call.getResult() &&
794 store.getAddr().getType() == ptrTy &&
795 mlir::DominanceInfo().properlyDominates(store.getAddr(), call)) {
796 sretSlot = store.getAddr();
801 auto alloca = cir::AllocaOp::create(
802 builder, call.getLoc(), ptrTy,
803 builder.getStringAttr(
"sret"),
804 builder.getI64IntegerAttr(sretAlign));
809 sretArgs.push_back(sretSlot);
810 sretArgs.append(newArgs.begin(), newArgs.end());
812 mlir::Type sretVoidTy = cir::VoidType::get(ctx);
813 auto newCall = cir::CallOp::create(
814 builder, call.getLoc(), call.getCalleeAttr(), sretVoidTy, sretArgs);
815 for (mlir::NamedAttribute attr : call->getAttrs())
816 if (!newCall->hasAttr(
attr.getName()))
817 newCall->setAttr(
attr.getName(),
attr.getValue());
824 mlir::ArrayAttr argAttrs = call->getAttrOfType<mlir::ArrayAttr>(
"arg_attrs");
825 bool needsArgAttrUpdate =
826 llvm::any_of(fc.argInfos, [](
const ArgClassification &ac) {
827 return ac.kind == ArgKind::Ignore || ac.kind == ArgKind::Extend ||
828 ac.kind == ArgKind::Indirect || ac.kind == ArgKind::Expand ||
829 getFlattenedCoercedType(ac);
831 if (needsArgAttrUpdate)
832 argAttrs = updateArgAttrs(ctx, origCallArgTypes, argAttrs, fc);
833 applySretSlotAttrs(newCall, argAttrs, origRetTy, sretAlign, builder);
841 builder.setInsertionPointAfter(newCall);
842 auto load = cir::LoadOp::create(builder, call.getLoc(), origRetTy, sretSlot,
847 cir::SyncScopeKindAttr(),
850 call.getResult().replaceAllUsesWith(load);
858 mlir::FunctionOpInterface funcOpInterface,
const FunctionClassification &fc,
859 mlir::OpBuilder &builder) {
865 cir::FuncOp funcOp = mlir::cast<cir::FuncOp>(funcOpInterface);
867 if (!needsRewrite(fc))
868 return mlir::success();
872 mlir::MLIRContext *ctx = funcOp->getContext();
877 assert(oldResultTypes.size() <= 1 &&
878 "CIR functions return zero or one value");
881 if (mlir::failed(buildNewArgTypes(oldArgTypes, fc, newArgTypes,
882 [&]() {
return funcOp.emitOpError(); })))
883 return mlir::failure();
885 mlir::Type voidTy = cir::VoidType::get(ctx);
886 mlir::Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
887 mlir::Type newRetTy = computeNewReturnType(
888 origRetTy, fc.returnInfo, ctx, [&]() { return funcOp.emitOpError(); });
890 return mlir::failure();
900 fc.returnInfo.kind == ArgKind::Indirect && !oldResultTypes.empty();
902 newArgTypes.insert(newArgTypes.begin(), cir::PointerType::get(origRetTy));
904 if (funcOp.isDefinition()) {
905 mlir::Region &body = funcOp->getRegion(0);
911 body.front().insertArgument(0u, cir::PointerType::get(origRetTy),
913 insertSRetStores(funcOp, origRetTy, builder);
923 insertArgCoercion(funcOp, fc, builder, dl, hasSRet);
928 if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType &&
929 !oldResultTypes.empty() && fc.returnInfo.coercedType != origRetTy)
930 insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType,
933 mlir::Block &entry = body.front();
942 unsigned blockArgIdx = hasSRet ? 1 : 0;
943 for (
auto [i, ac] : llvm::enumerate(fc.argInfos)) {
944 if (blockArgIdx >= entry.getNumArguments())
946 if (ac.kind == ArgKind::Ignore) {
947 mlir::BlockArgument arg = entry.getArgument(blockArgIdx);
948 if (!arg.use_empty()) {
949 builder.setInsertionPointToStart(&entry);
951 createIgnoredValue(builder, funcOp.getLoc(), arg.getType());
952 arg.replaceAllUsesWith(poison);
954 entry.eraseArgument(blockArgIdx);
958 blockArgIdx += flatTy.getNumElements();
959 else if (ac.kind == ArgKind::Expand)
960 blockArgIdx += cast<cir::RecordType>(oldArgTypes[i]).getNumElements();
972 if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) {
973 assert(mlir::isa<cir::VoidType>(newRetTy) &&
974 "Ignore-return path requires the new return type to be void");
976 funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
977 for (cir::ReturnOp r : returns) {
978 if (r.getNumOperands() == 0)
980 builder.setInsertionPoint(r);
981 cir::ReturnOp::create(builder, r.getLoc());
987 mlir::Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
988 funcOp.setFunctionTypeAttr(mlir::TypeAttr::get(newFnTy));
995 bool needsArgAttrUpdate =
996 hasSRet || llvm::any_of(fc.argInfos, [](
const ArgClassification &ac) {
997 return ac.kind == ArgKind::Ignore || ac.kind == ArgKind::Extend ||
998 ac.kind == ArgKind::Indirect || ac.kind == ArgKind::Expand ||
999 getFlattenedCoercedType(ac);
1001 if (needsArgAttrUpdate) {
1002 auto existing = funcOp->getAttrOfType<mlir::ArrayAttr>(
"arg_attrs");
1003 mlir::ArrayAttr updated = updateArgAttrs(ctx, oldArgTypes, existing, fc);
1009 builder, origRetTy, fc.returnInfo.indirectAlign.value(),
1010 funcOp.isDefinition());
1012 withSret.push_back(mlir::DictionaryAttr::get(ctx, sretAttrs));
1013 llvm::append_range(withSret, updated);
1014 funcOp->setAttr(
"arg_attrs", mlir::ArrayAttr::get(ctx, withSret));
1016 funcOp->setAttr(
"arg_attrs", updated);
1022 if (fc.returnInfo.kind == ArgKind::Extend) {
1023 auto existing = funcOp->getAttrOfType<mlir::ArrayAttr>(
"res_attrs");
1024 funcOp->setAttr(
"res_attrs", updateResAttrs(ctx, existing, fc.returnInfo));
1027 return mlir::success();
1032 const FunctionClassification &fc,
1033 mlir::OpBuilder &builder) {
1034 if (!needsRewrite(fc))
1035 return mlir::success();
1037 if (mlir::isa<cir::TryCallOp>(callOp))
1038 return callOp->emitOpError()
1039 <<
"TryCallOp not yet implemented in CallConvLowering";
1041 auto call = mlir::cast<cir::CallOp>(callOp);
1042 if (call.isIndirect())
1043 return call.emitOpError()
1044 <<
"indirect call not yet implemented in CallConvLowering";
1046 mlir::MLIRContext *ctx = callOp->getContext();
1047 auto enclosingFunc = call->getParentOfType<mlir::FunctionOpInterface>();
1049 builder.setInsertionPoint(call);
1052 mlir::ValueRange argOperands = call.getArgOperands();
1053 newArgs.reserve(argOperands.size());
1064 llvm::append_range(origCallArgTypes, argOperands.getTypes());
1065 if (argOperands.size() > fc.argInfos.size())
1066 return call.emitOpError()
1067 <<
"variadic arguments not yet implemented in CallConvLowering";
1068 assert(fc.argInfos.size() == argOperands.size() &&
1069 "call operand count must match classified arg count");
1070 for (
auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
1071 if (ac.kind == ArgKind::Ignore)
1073 mlir::Value arg = argOperands[idx];
1081 if (arg.getType() != flatTy) {
1082 SmallPtrSet<mlir::Operation *, 4> coercionOps;
1083 mlir::Value coercedPtr =
1084 emitCoercionToMemory(builder, call.getLoc(), flatTy, arg,
1085 enclosingFunc, dl, coercionOps);
1086 for (
auto [f, fieldTy] : llvm::enumerate(flatTy.getMembers())) {
1087 mlir::Type fieldPtrTy = cir::PointerType::get(fieldTy);
1089 cir::GetMemberOp::create(builder, call.getLoc(), fieldPtrTy,
1091 newArgs.push_back(cir::LoadOp::create(builder, call.getLoc(), fieldTy,
1092 fieldPtr.getResult()));
1095 emitStructFieldArgs(builder, call.getLoc(), arg, flatTy, newArgs,
1096 replacedWholeLoads);
1098 }
else if (ac.kind == ArgKind::Expand) {
1101 auto recTy = cast<cir::RecordType>(arg.getType());
1103 "Expand classification requires a struct type, not a union");
1104 emitStructFieldArgs(builder, call.getLoc(), arg, recTy, newArgs,
1105 replacedWholeLoads);
1106 }
else if (ac.kind == ArgKind::Direct && ac.coercedType &&
1107 arg.getType() != ac.coercedType) {
1108 arg = emitCoercion(builder, call.getLoc(), ac.coercedType, arg,
1110 newArgs.push_back(arg);
1111 }
else if (ac.kind == ArgKind::Indirect) {
1117 mlir::Type argTy = arg.getType();
1118 auto ptrTy = cir::PointerType::get(argTy);
1119 uint64_t align = ac.indirectAlign.value();
1120 StringRef slotName = ac.byVal ?
"byval" :
"byref";
1121 auto slot = cir::AllocaOp::create(builder, call.getLoc(), ptrTy,
1122 builder.getStringAttr(slotName),
1123 builder.getI64IntegerAttr(align));
1124 cir::StoreOp::create(builder, call.getLoc(), arg, slot);
1126 newArgs.push_back(arg);
1128 newArgs.push_back(arg);
1132 bool hasResult = call.getNumResults() > 0;
1133 mlir::Type origRetTy =
1134 hasResult ? call.getResult().getType() : cir::VoidType::get(ctx);
1140 if (fc.returnInfo.kind == ArgKind::Indirect && hasResult) {
1141 rewriteIndirectReturnCall(call, fc, newArgs, origRetTy, origCallArgTypes,
1143 return mlir::success();
1146 mlir::Type callRetTy = origRetTy;
1147 if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
1148 callRetTy = cir::VoidType::get(ctx);
1149 bool returnNeedsCoercion =
1150 hasResult && fc.returnInfo.kind == ArgKind::Direct &&
1151 fc.returnInfo.coercedType && fc.returnInfo.coercedType != origRetTy;
1152 if (returnNeedsCoercion)
1153 callRetTy = fc.returnInfo.coercedType;
1155 builder.setInsertionPoint(call);
1156 auto newCall = cir::CallOp::create(builder, call.getLoc(),
1157 call.getCalleeAttr(), callRetTy, newArgs);
1158 for (mlir::NamedAttribute attr : call->getAttrs())
1159 if (!newCall->hasAttr(attr.getName()))
1160 newCall->setAttr(attr.getName(), attr.getValue());
1164 if (returnNeedsCoercion) {
1165 builder.setInsertionPointAfter(newCall);
1166 mlir::Value coercedBack =
1167 emitCoercion(builder, call.getLoc(), origRetTy, newCall.getResult(),
1169 call.getResult().replaceAllUsesWith(coercedBack);
1176 bool needsArgAttrUpdate =
1177 llvm::any_of(fc.argInfos, [](
const ArgClassification &ac) {
1178 return ac.kind == ArgKind::Ignore || ac.kind == ArgKind::Extend ||
1179 ac.kind == ArgKind::Indirect || ac.kind == ArgKind::Expand ||
1180 getFlattenedCoercedType(ac);
1182 if (needsArgAttrUpdate) {
1183 auto existing = call->getAttrOfType<mlir::ArrayAttr>(
"arg_attrs");
1184 newCall->setAttr(
"arg_attrs",
1185 updateArgAttrs(ctx, origCallArgTypes, existing, fc));
1187 if (fc.returnInfo.kind == ArgKind::Extend) {
1188 auto existing = call->getAttrOfType<mlir::ArrayAttr>(
"res_attrs");
1189 newCall->setAttr(
"res_attrs", updateResAttrs(ctx, existing, fc.returnInfo));
1192 if (hasResult && fc.returnInfo.kind == ArgKind::Ignore) {
1197 if (!call.getResult().use_empty()) {
1198 builder.setInsertionPointAfter(newCall);
1199 mlir::Value poison =
1200 createIgnoredValue(builder, call.getLoc(), origRetTy);
1201 call.getResult().replaceAllUsesWith(poison);
1203 }
else if (hasResult && !returnNeedsCoercion) {
1205 call.getResult().replaceAllUsesWith(newCall.getResult());
1215 SmallPtrSet<mlir::Operation *, 4> erased;
1216 for (cir::LoadOp wholeLoad : replacedWholeLoads)
1217 if (erased.insert(wholeLoad).second && wholeLoad.use_empty())
1220 return mlir::success();
mlir::LogicalResult rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp, const mlir::abi::FunctionClassification &fc, mlir::OpBuilder &builder) override
mlir::LogicalResult rewriteCallSite(mlir::Operation *callOp, const mlir::abi::FunctionClassification &fc, mlir::OpBuilder &builder) override
C++ view class that accepts both !cir.struct and !cir.union types.
llvm::ArrayRef< mlir::Type > getMembers() const
size_t getNumElements() const
const internal::VariadicAllOfMatcher< Attr > attr