package jef.database.dialect; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Savepoint; import java.sql.Statement; import java.sql.Types; import java.util.Arrays; import javax.persistence.PersistenceException; import jef.common.log.LogUtil; import jef.database.ConnectInfo; import jef.database.DbMetaData; import jef.database.DbUtils; import jef.database.ORMConfig; import jef.database.dialect.ColumnType.AutoIncrement; import jef.database.dialect.ColumnType.Clob; import jef.database.dialect.ColumnType.Varchar; import jef.database.dialect.handler.LimitHandler; import jef.database.dialect.handler.LimitOffsetLimitHandler; import jef.database.dialect.type.AutoIncrementMapping; import jef.database.exception.JDBCExceptionHelper; import jef.database.exception.TemplatedViolatedConstraintNameExtracter; import jef.database.exception.ViolatedConstraintNameExtracter; import jef.database.jdbc.JDBCTarget; import jef.database.jdbc.statement.DelegatingPreparedStatement; import jef.database.jdbc.statement.DelegatingStatement; import jef.database.jsqlparser.expression.BinaryExpression; import jef.database.jsqlparser.expression.Function; import jef.database.jsqlparser.expression.Interval; import jef.database.meta.DbProperty; import jef.database.meta.Feature; import jef.database.query.Func; import jef.database.query.Scientific; import jef.database.query.function.CastFunction; import jef.database.query.function.EmuDateAddSubByTimesatmpadd; import jef.database.query.function.EmuDatediffByTimestampdiff; import jef.database.query.function.EmuDecodeWithCase; import jef.database.query.function.EmuJDBCTimestampFunction; import jef.database.query.function.EmuLocateOnPostgres; import jef.database.query.function.EmuPostgreTimestampDiff; import jef.database.query.function.EmuPostgresAddDate; import jef.database.query.function.EmuPostgresExtract; import jef.database.query.function.EmuPostgresSubDate; import jef.database.query.function.NoArgSQLFunction; import jef.database.query.function.StandardSQLFunction; import jef.database.query.function.TemplateFunction; import jef.database.query.function.VarArgsSQLFunction; import jef.database.support.RDBMS; import jef.tools.StringUtils; import jef.tools.collection.CollectionUtils; import jef.tools.string.JefStringReader; public class PostgreSqlDialect extends AbstractDialect { protected static final String JDBC_URL_FORMAT = "jdbc:postgresql://%1$s:%2$s/%3$s"; protected static final int DEFAULT_PORT = 5432; public PostgreSqlDialect() { features = CollectionUtils.identityHashSet(); features.addAll(Arrays.asList(Feature.ALTER_FOR_EACH_COLUMN, Feature.COLUMN_ALTERATION_SYNTAX, Feature.SUPPORT_CONCAT, Feature.SUPPORT_SEQUENCE, Feature.SUPPORT_LIMIT, Feature.AI_TO_SEQUENCE_WITHOUT_DEFAULT, Feature.SUPPORT_COMMENT)); loadKeywords("postgresql_keywords.properties"); registerNative(Func.coalesce); registerAlias(Func.nvl, "coalesce"); registerNative(Scientific.cot); registerNative(Scientific.exp); registerNative(Scientific.ln, "log"); registerNative(new StandardSQLFunction("cbrt")); registerNative(Scientific.radians); registerNative(Scientific.degrees); registerNative(new StandardSQLFunction("stddev")); registerNative(new StandardSQLFunction("variance")); registerNative(new NoArgSQLFunction("random")); registerAlias(Scientific.rand, "random"); registerNative(Func.cast, new CastFunction()); registerNative(Func.mod); registerNative(Func.nullif); registerNative(Func.round); registerNative(Func.trunc); registerNative(Func.ceil); registerNative(Func.floor); registerNative(Func.translate); registerNative(new StandardSQLFunction("chr")); registerNative(Func.lower); registerNative(Func.upper); registerAlias("lcase", "lower"); registerAlias("ucase", "upper"); registerNative(new StandardSQLFunction("substr")); registerAlias(Func.substring, "substr"); registerNative(new StandardSQLFunction("initcap")); registerNative(new StandardSQLFunction("to_ascii")); registerNative(new StandardSQLFunction("quote_ident")); registerNative(new StandardSQLFunction("quote_literal")); registerNative(new StandardSQLFunction("md5")); registerNative(new StandardSQLFunction("ascii")); registerNative(new StandardSQLFunction("char_length")); registerAlias(Func.length, "char_length"); registerNative(new StandardSQLFunction("bit_length")); registerNative(new StandardSQLFunction("octet_length")); registerNative(new StandardSQLFunction("age"));// 单参数时计算当前时间与指定时间的差,双参数时计算第一个减去第二个时间 registerNative(new NoArgSQLFunction("current_date", false)); registerAlias(Func.current_date, "current_date"); registerNative(new NoArgSQLFunction("current_time", false)); registerAlias(Func.current_time, "current_time"); registerNative(new NoArgSQLFunction("current_timestamp", false), "now"); registerAlias(Func.current_timestamp, "current_timestamp"); registerAlias(Func.now, "current_timestamp"); registerAlias("sysdate", "current_timestamp"); registerNative(new StandardSQLFunction("date_trunc")); registerNative(new NoArgSQLFunction("localtime", false)); registerNative(new NoArgSQLFunction("localtimestamp", false)); registerNative(new NoArgSQLFunction("timeofday")); registerNative(new StandardSQLFunction("isfinite")); registerNative(Func.date); registerCompatible(Func.time, new TemplateFunction("time", "cast(%s as time)")); registerNative(new NoArgSQLFunction("current_user", false)); registerNative(new NoArgSQLFunction("session_user", false)); registerNative(new NoArgSQLFunction("user", false)); registerNative(new NoArgSQLFunction("current_database", true)); registerNative(new NoArgSQLFunction("current_schema", true)); registerNative(new StandardSQLFunction("to_char")); registerNative(new StandardSQLFunction("to_date")); registerNative(new StandardSQLFunction("to_timestamp")); registerNative(new StandardSQLFunction("to_number")); registerNative(new StandardSQLFunction("bool_and")); registerNative(new StandardSQLFunction("bool_or")); registerNative(new StandardSQLFunction("bit_and")); registerNative(new StandardSQLFunction("bit_or")); registerNative(new StandardSQLFunction("extract")); registerNative(Func.replace); registerNative(Func.trim); registerNative(Func.ltrim); registerNative(Func.rtrim); registerNative(Func.lpad); registerNative(Func.rpad); registerCompatible(Func.concat, new VarArgsSQLFunction("", "||", "")); // Derby是没有concat函数的,要改写为相加 registerCompatible(Func.locate, new EmuLocateOnPostgres()); registerCompatible(Func.year, new EmuPostgresExtract("year")); registerCompatible(Func.month, new EmuPostgresExtract("month")); registerCompatible(Func.day, new EmuPostgresExtract("day")); registerCompatible(Func.hour, new EmuPostgresExtract("hour")); registerCompatible(Func.minute, new EmuPostgresExtract("minute")); registerCompatible(Func.second, new EmuPostgresExtract("second")); registerCompatible(Func.adddate, new EmuPostgresAddDate()); registerCompatible(Func.subdate, new EmuPostgresSubDate()); registerCompatible(Func.add_months, new TemplateFunction("add_months", "{fn timestampadd(SQL_TSI_MONTH,%2$s,%1$s)}")); registerCompatible(null, new TemplateFunction("timestamp", "%1$s::TIMESTAMP"), "timestamp"); registerCompatible(Func.timestampdiff, new EmuPostgreTimestampDiff());// 等PG的驱动完善了,可以改为EmuJDBCTimestampFunction registerCompatible(Func.timestampadd, new EmuJDBCTimestampFunction(Func.timestampadd, this)); registerCompatible(Func.datediff, new EmuDatediffByTimestampdiff());// 等PG的驱动完善了,可以改为EmuJDBCTimestampFunction registerCompatible(Func.adddate, new EmuDateAddSubByTimesatmpadd(Func.adddate)); registerCompatible(Func.subdate, new EmuDateAddSubByTimesatmpadd(Func.subdate)); registerCompatible(Func.decode, new EmuDecodeWithCase()); registerCompatible(Func.lengthb, new TemplateFunction("lengthb", "bit_length(%s)/8")); registerCompatible(Func.str, new CastFunction("str", "varchar")); setProperty(DbProperty.ADD_COLUMN, "ADD COLUMN"); setProperty(DbProperty.MODIFY_COLUMN, "ALTER COLUMN"); setProperty(DbProperty.DROP_COLUMN, "DROP COLUMN"); setProperty(DbProperty.CHECK_SQL, "select 1"); setProperty(DbProperty.SEQUENCE_FETCH, "select nextval('%s')"); setProperty(DbProperty.WRAP_FOR_KEYWORD, "\"\""); setProperty(DbProperty.GET_IDENTITY_FUNCTION, "SELECT currval('%tableName%_%columnName%_seq')"); typeNames.put(Types.BLOB, "bytea", Types.VARBINARY); typeNames.put(Types.CLOB, "text", 0); typeNames.put(Types.BOOLEAN, "boolean", 0,"bool"); typeNames.put(Types.TINYINT, "int2", 0); typeNames.put(Types.SMALLINT, "int2", 0); typeNames.put(Types.INTEGER, "int4", 0); typeNames.put(Types.BIGINT, "int8", 0); typeNames.put(Types.FLOAT, 6, "float4", 0); typeNames.put(Types.FLOAT, 15,"float8", Types.DOUBLE); typeNames.put(Types.FLOAT, 38,"numeric($p, $s)", Types.NUMERIC); typeNames.put(Types.DOUBLE, 15,"float8", 0); typeNames.put(Types.DOUBLE, 38,"numeric($p, $s)", Types.NUMERIC); typeNames.put(Types.NUMERIC, "numeric($p, $s)", 0); } public RDBMS getName() { return RDBMS.postgresql; } @Override public void accept(DbMetaData db) { super.accept(db); try { ensureUserFunction(this.functions.get("timestampdiff"), db); } catch (SQLException e) { LogUtil.exception("Initlize user function error.", e); } } public String getDriverClass(String url) { return "org.postgresql.Driver"; } @Override public String generateUrl(String host, int port, String pathOrName) { String url = String.format(JDBC_URL_FORMAT, host, (port <= 0 ? DEFAULT_PORT : port), pathOrName); if (ORMConfig.getInstance().isDebugMode()) { LogUtil.show(url); } return url; } @Override protected String getComment(AutoIncrement column, boolean flag) { /* * PG的自增主键后台其实是用类似于Oracle的实现完成的,后台会自动创建名为“: 表名_列命_seq这样一个sequence * 支持类似于Oracle的currval和nextval语法,具体写法如下:select * nextval('ca_asset_asset_id_seq'); select * currval('ca_asset_asset_id_seq'); 这为我们操作PG的Sequence提供了方便。 * 要注意这里的currval含义用法和Oracle一样 * ,不是获取sequence当前的值,而是返回当前sesssion中上一次获取过的seq值。 */ // if(sequenceMode){ // if(column.getSqlType()==Types.BIGINT){ // return flag?"int8 not null":"int8"; // }else{ // return flag?"int4 not null":"int4"; // } // }else{ if (column.getSqlType() == Types.BIGINT) { return flag ? "serial8 not null" : "serial8"; } else { return flag ? "serial4 not null" : "serial4"; } // } } public ColumnType getProprtMetaFromDbType(jef.database.meta.Column column) { if ("text".equals(column.getDataType())) { return new Clob(); } else if ("money".equals(column.getDataType())) { return new Varchar(column.getColumnSize() + 2); } else { return super.getProprtMetaFromDbType(column); } } // ,"jdbc:postgresql://localhost/soft" public void parseDbInfo(ConnectInfo connectInfo) { JefStringReader reader = new JefStringReader(connectInfo.getUrl()); reader.setIgnoreChars(' '); reader.consume("jdbc:postgresql:"); reader.omitChars('/'); String host = reader.readToken('/'); String name = reader.readToken(';', '?', '/'); connectInfo.setHost(host); connectInfo.setDbname(name); } @Override public void processIntervalExpression(BinaryExpression parent, Interval interval) { interval.toPostgresMode(); } @Override public void processIntervalExpression(Function func, Interval interval) { interval.toPostgresMode(); } @Override public long getColumnAutoIncreamentValue(AutoIncrementMapping mapping, JDBCTarget db) { String tableName = mapping.getMeta().getTableName(false).toLowerCase(); String seqname = tableName + "_" + mapping.lowerColumnName() + "_seq"; String sql = String.format("select nextval('%s')", seqname); if (ORMConfig.getInstance().isDebugMode()) { LogUtil.show(sql + " | " + db.getTransactionId()); } try { Statement st = db.createStatement(); ResultSet rs = null; try { rs = st.executeQuery(sql); rs.next(); return rs.getLong(1); } finally { DbUtils.close(rs); DbUtils.close(st); } } catch (SQLException e) { throw new PersistenceException(e); } } @Override public Statement wrap(Statement stmt, boolean isInJpaTx) throws SQLException { if (isInJpaTx && ORMConfig.getInstance().isKeepTxForPG()) { return new PGTxStatement(stmt); } else { return stmt; } } @Override public PreparedStatement wrap(PreparedStatement stmt, boolean isInJpaTx) throws SQLException { if (isInJpaTx && ORMConfig.getInstance().isKeepTxForPG()) { return new PGTxPreparedStatement(stmt); } else { return stmt; } } @Override public String getDefaultSchema() { return "public"; } @Override public String getSchema(String schema) { return schema != null ? schema.toLowerCase() : schema; } /** * 如果确定不需要这个特性支持,可以使用 db.keep.tx.for.postgresql=false来关闭 * * @author jiyi * */ private static final class PGTxPreparedStatement extends DelegatingPreparedStatement { private Connection conn; public PGTxPreparedStatement(PreparedStatement s) throws SQLException { super(s); this.conn = s.getConnection(); } @Override public ResultSet executeQuery() throws SQLException { Savepoint sp = conn.setSavepoint(); try { return ((PreparedStatement) _stmt).executeQuery(); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public int[] executeBatch() throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeBatch(); } catch (SQLException e) { conn.rollback(sp); // PG在Batch模式下抛出的顶层错误是难以理解的。直接抛出nextException即可。 //throw e.getNextException() == null ? e : e.getNextException(); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public int executeUpdate() throws SQLException { Savepoint sp = conn.setSavepoint(); try { return ((PreparedStatement) _stmt).executeUpdate(); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public boolean execute() throws SQLException { Savepoint sp = conn.setSavepoint(); try { return ((PreparedStatement) _stmt).execute(); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } } private static final class PGTxStatement extends DelegatingStatement { private Connection conn; public PGTxStatement(Statement s) throws SQLException { super(s); this.conn = s.getConnection(); } @Override public int[] executeBatch() throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeBatch(); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public int executeUpdate(String sql) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeUpdate(sql); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public boolean execute(String sql) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.execute(sql); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeUpdate(sql, autoGeneratedKeys); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeUpdate(sql, columnIndexes); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public int executeUpdate(String sql, String[] columnNames) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeUpdate(sql, columnNames); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.execute(sql, autoGeneratedKeys); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public boolean execute(String sql, int[] columnIndexes) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.execute(sql, columnIndexes); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public boolean execute(String sql, String[] columnNames) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.execute(sql, columnNames); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } @Override public ResultSet executeQuery(String sql) throws SQLException { Savepoint sp = conn.setSavepoint(); try { return _stmt.executeQuery(sql); } catch (SQLException e) { conn.rollback(sp); throw e; } finally { conn.releaseSavepoint(sp); } } } @Override public boolean containKeyword(String name) { return keywords.contains(StringUtils.lowerCase(name)); } private final LimitHandler limit = new LimitOffsetLimitHandler(); @Override public LimitHandler getLimitHandler() { return limit; } @Override public ViolatedConstraintNameExtracter getViolatedConstraintNameExtracter() { return EXTRACTER; } private static ViolatedConstraintNameExtracter EXTRACTER = new TemplatedViolatedConstraintNameExtracter() { public String extractConstraintName(SQLException sqle) { try { int sqlState = Integer.valueOf(JDBCExceptionHelper.extractSqlState(sqle)).intValue(); switch (sqlState) { // CHECK VIOLATION case 23514: return sqle.getMessage(); // UNIQUE VIOLATION case 23505: return sqle.getMessage(); // FOREIGN KEY VIOLATION case 23503: return sqle.getMessage(); // NOT NULL VIOLATION case 23502: return sqle.getMessage(); // TODO: RESTRICT VIOLATION case 23001: return null; // ALL OTHER default: return null; } } catch (NumberFormatException nfe) { return null; } } }; }