package org.scribble.ast; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; import org.antlr.runtime.tree.CommonTree; import org.scribble.ast.global.GProtocolDecl; import org.scribble.ast.local.LProtocolDecl; import org.scribble.del.ScribDel; import org.scribble.main.ScribbleException; import org.scribble.sesstype.kind.Global; import org.scribble.sesstype.kind.Kind; import org.scribble.sesstype.kind.ProtocolKind; import org.scribble.sesstype.name.DataType; import org.scribble.sesstype.name.MessageSigName; import org.scribble.sesstype.name.ModuleName; import org.scribble.sesstype.name.ProtocolName; import org.scribble.util.ScribUtil; import org.scribble.visit.AstVisitor; public class Module extends ScribNodeBase { public final ModuleDecl moddecl; // Using (implicitly bounded) nested wildcards for mixed element lists (better practice to use separate lists?) private final List<ImportDecl<?>> imports; private final List<NonProtocolDecl<?>> data; private final List<ProtocolDecl<?>> protos; public Module(CommonTree source, ModuleDecl moddecl, List<ImportDecl<?>> imports, List<NonProtocolDecl<?>> data, List<ProtocolDecl<?>> protos) { super(source); this.moddecl = moddecl; this.imports = new LinkedList<>(imports); this.data = new LinkedList<>(data); this.protos = new LinkedList<>(protos); } @Override protected Module copy() { return new Module(this.source, this.moddecl, this.imports, this.data, this.protos); } @Override public Module clone() { ModuleDecl moddecl = (ModuleDecl) this.moddecl.clone(); List<ImportDecl<?>> imports = ScribUtil.cloneList(this.imports); List<NonProtocolDecl<?>> data = ScribUtil.cloneList(this.data); List<ProtocolDecl<?>> protos = ScribUtil.cloneList(this.protos); return AstFactoryImpl.FACTORY.Module(this.source, moddecl, imports, data, protos); } public Module reconstruct(ModuleDecl moddecl, List<ImportDecl<?>> imports, List<NonProtocolDecl<?>> data, List<ProtocolDecl<?>> protos) { ScribDel del = del(); Module m = new Module(this.source, moddecl, imports, data, protos); m = (Module) m.del(del); return m; } @Override public Module visitChildren(AstVisitor nv) throws ScribbleException { ModuleDecl moddecl = (ModuleDecl) visitChild(this.moddecl, nv); // class equality check probably too restrictive List<ImportDecl<?>> imports = ScribNodeBase.visitChildListWithClassEqualityCheck(this, this.imports, nv); List<NonProtocolDecl<?>> data = ScribNodeBase.visitChildListWithClassEqualityCheck(this, this.data, nv); List<ProtocolDecl<?>> protos = ScribNodeBase.visitChildListWithClassEqualityCheck(this, this.protos, nv); return reconstruct(moddecl, imports, data, protos); } public ModuleName getFullModuleName() { return this.moddecl.getFullModuleName(); } @Override public String toString() { String s = moddecl.toString(); for (ImportDecl<? extends Kind> id : this.imports) { s += "\n" + id; } for (NonProtocolDecl<? extends Kind> dtd : this.data) { s += "\n" + dtd; } for (ProtocolDecl<? extends ProtocolKind> pd : this.protos) { s += "\n" + pd; } return s; } // ptn simple alias name public DataTypeDecl getDataTypeDecl(DataType simpname) // Simple name (as for getProtocolDecl) { for (NonProtocolDecl<? extends Kind> dtd : this.data) { if (dtd.isDataTypeDecl() && dtd.getDeclName().equals(simpname)) { return (DataTypeDecl) dtd; } } throw new RuntimeException("Data type not found: " + simpname); } // msn simple alias name public MessageSigNameDecl getMessageSigDecl(MessageSigName simpname) { for (NonProtocolDecl<?> dtd : this.data) { if (dtd instanceof MessageSigNameDecl && dtd.getDeclName().equals(simpname)) { return (MessageSigNameDecl) dtd; } } throw new RuntimeException("Message signature not found: " + simpname); } public List<ImportDecl<?>> getImportDecls() { return Collections.unmodifiableList(this.imports); } public List<NonProtocolDecl<?>> getNonProtocolDecls() { return Collections.unmodifiableList(this.data); } public List<ProtocolDecl<?>> getProtocolDecls() { return Collections.unmodifiableList(this.protos); } public List<GProtocolDecl> getGlobalProtocolDecls() { return getProtocolDecls(IS_GLOBALPROTOCOLDECL, TO_GLOBALPROTOCOLDECL); } public List<LProtocolDecl> getLocalProtocolDecls() { return getProtocolDecls(IS_LOCALPROTOCOLDECL, TO_LOCALPROTOCOLDECL); } private <T extends ProtocolDecl<?>> List<T> getProtocolDecls(Predicate<ProtocolDecl<?>> filter, Function<ProtocolDecl<?>, T> cast) { return this.protos.stream().filter(filter).map(cast).collect(Collectors.toList()); } public <K extends ProtocolKind> boolean hasProtocolDecl(ProtocolName<K> simpname) { return hasProtocolDecl(this.protos, simpname); } // pn is simple name // separate into global/local? public <K extends ProtocolKind> ProtocolDecl<K> getProtocolDecl(ProtocolName<K> simpname) { return getProtocolDecl(this.protos, simpname); } private static <K extends ProtocolKind> boolean hasProtocolDecl(List<ProtocolDecl<?>> pds, ProtocolName<K> simpname) { return pds.stream() .filter((pd) -> pd.header.getDeclName().equals(simpname) && (simpname.getKind().equals(Global.KIND)) ? pd.isGlobal() : pd.isLocal()) .count() > 0; } // pn is simple name private static <K extends ProtocolKind> ProtocolDecl<K> getProtocolDecl(List<ProtocolDecl<?>> pds, ProtocolName<K> simpname) { List<ProtocolDecl<?>> filtered = pds.stream() .filter((pd) -> pd.header.getDeclName().equals(simpname) && (simpname.getKind().equals(Global.KIND)) ? pd.isGlobal() : pd.isLocal()) .collect(Collectors.toList()); if (filtered.size() == 0) { throw new RuntimeException("Protocol not found: " + simpname); } /*if (filtered.size() > 1) { throw new RuntimeException("Found duplicate protocol decls: " + simpname); // Just return first -- allows Do/DoArgListDel name disambiguation to go through, and later caught on leaving Module }*/ @SuppressWarnings("unchecked") ProtocolDecl<K> res = (ProtocolDecl<K>) filtered.get(0); return res; } private static final Predicate<ProtocolDecl<?>> IS_GLOBALPROTOCOLDECL = (pd) -> pd.isGlobal(); private static final Predicate<ProtocolDecl<?>> IS_LOCALPROTOCOLDECL = (pd) -> pd.isLocal(); private static final Function <ProtocolDecl<?>, GProtocolDecl> TO_GLOBALPROTOCOLDECL = (pd) -> (GProtocolDecl) pd; private static final Function <ProtocolDecl<?>, LProtocolDecl> TO_LOCALPROTOCOLDECL = (pd) -> (LProtocolDecl) pd; }